This is an automated email from the ASF dual-hosted git repository. kabhwan pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new fd9e5760bae [SPARK-40657] Add support for Java classes in Protobuf functions fd9e5760bae is described below commit fd9e5760bae847f47c9c108f0e58814748e0d9b1 Author: Raghu Angadi <raghu.ang...@databricks.com> AuthorDate: Fri Oct 21 15:46:50 2022 +0900 [SPARK-40657] Add support for Java classes in Protobuf functions ### What changes were proposed in this pull request? Adds support for compiled Java classes to Protobuf functions. This is tested with Protobuf v3 classes. V2 vs V3 issues will be handled in a separate PR. The main changes in this PR: - Changes to top level API: - Adds new version that takes just the class name. - Changes the order of arguments for existing API with descriptor files (`messageName` and `descFilePath` are swapped). - Protobuf utils methods to create descriptor from Java class name. - Many unit tests are updated to check both versions: (1) with descriptor file and (2) with Java class name. - Maven build updates to generate Java classes to use in tests. - Miscellaneous changes: - Adds `protos` to package name in `proto` files used in tests. - A few TODO comments about improvements ### Why are the changes needed? Java compiled classes are a common method for users to provide Protobuf definitions. ### Does this PR introduce _any_ user-facing change? No. This updates the interface, but for a new feature in active development. ### How was this patch tested? - Unit tests Closes #38286 from rangadi/protobuf-java. 
Authored-by: Raghu Angadi <raghu.ang...@databricks.com> Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com> --- connector/protobuf/pom.xml | 23 +- .../sql/protobuf/CatalystDataToProtobuf.scala | 10 +- .../sql/protobuf/ProtobufDataToCatalyst.scala | 34 ++- .../org/apache/spark/sql/protobuf/functions.scala | 58 +++- .../spark/sql/protobuf/utils/ProtobufUtils.scala | 65 ++++- .../sql/protobuf/utils/SchemaConverters.scala | 4 + .../test/resources/protobuf/catalyst_types.proto | 4 +- .../test/resources/protobuf/functions_suite.proto | 4 +- .../src/test/resources/protobuf/serde_suite.proto | 6 +- .../ProtobufCatalystDataConversionSuite.scala | 97 +++++-- .../sql/protobuf/ProtobufFunctionsSuite.scala | 318 +++++++++++++-------- .../spark/sql/protobuf/ProtobufSerdeSuite.scala | 9 +- project/SparkBuild.scala | 6 +- python/pyspark/sql/protobuf/functions.py | 22 +- 14 files changed, 437 insertions(+), 223 deletions(-) diff --git a/connector/protobuf/pom.xml b/connector/protobuf/pom.xml index 0515f128b8d..b934c7f831a 100644 --- a/connector/protobuf/pom.xml +++ b/connector/protobuf/pom.xml @@ -83,7 +83,6 @@ <version>${protobuf.version}</version> <scope>compile</scope> </dependency> - </dependencies> <build> <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory> @@ -110,6 +109,28 @@ </relocations> </configuration> </plugin> + <plugin> + <groupId>com.github.os72</groupId> + <artifactId>protoc-jar-maven-plugin</artifactId> + <version>3.11.4</version> + <!-- Generates Java classes for tests. TODO(Raghu): Generate descriptor files too. 
--> + <executions> + <execution> + <phase>generate-test-sources</phase> + <goals> + <goal>run</goal> + </goals> + <configuration> + <protocArtifact>com.google.protobuf:protoc:${protobuf.version}</protocArtifact> + <protocVersion>${protobuf.version}</protocVersion> + <inputDirectories> + <include>src/test/resources/protobuf</include> + </inputDirectories> + <addSources>test</addSources> + </configuration> + </execution> + </executions> + </plugin> </plugins> </build> </project> diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/CatalystDataToProtobuf.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/CatalystDataToProtobuf.scala index 145100268c2..b9f7907ea8c 100644 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/CatalystDataToProtobuf.scala +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/CatalystDataToProtobuf.scala @@ -25,17 +25,17 @@ import org.apache.spark.sql.types.{BinaryType, DataType} private[protobuf] case class CatalystDataToProtobuf( child: Expression, - descFilePath: String, - messageName: String) + messageName: String, + descFilePath: Option[String] = None) extends UnaryExpression { override def dataType: DataType = BinaryType - @transient private lazy val protoType = - ProtobufUtils.buildDescriptor(descFilePath, messageName) + @transient private lazy val protoDescriptor = + ProtobufUtils.buildDescriptor(messageName, descFilePathOpt = descFilePath) @transient private lazy val serializer = - new ProtobufSerializer(child.dataType, protoType, child.nullable) + new ProtobufSerializer(child.dataType, protoDescriptor, child.nullable) override def nullSafeEval(input: Any): Any = { val dynamicMessage = serializer.serialize(input).asInstanceOf[DynamicMessage] diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala index 
f08f8767997..cad2442f10c 100644 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala @@ -31,9 +31,9 @@ import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType, Struc private[protobuf] case class ProtobufDataToCatalyst( child: Expression, - descFilePath: String, messageName: String, - options: Map[String, String]) + descFilePath: Option[String] = None, + options: Map[String, String] = Map.empty) extends UnaryExpression with ExpectsInputTypes { @@ -55,10 +55,14 @@ private[protobuf] case class ProtobufDataToCatalyst( private lazy val protobufOptions = ProtobufOptions(options) @transient private lazy val messageDescriptor = - ProtobufUtils.buildDescriptor(descFilePath, messageName) + ProtobufUtils.buildDescriptor(messageName, descFilePath) + // TODO: Avoid carrying the file name. Read the contents of descriptor file only once + // at the start. Rest of the runs should reuse the buffer. Otherwise, it could + // cause inconsistencies if the file contents are changed by the user after a few days. + // Same for the write side in [[CatalystDataToProtobuf]]. 
@transient private lazy val fieldsNumbers = - messageDescriptor.getFields.asScala.map(f => f.getNumber) + messageDescriptor.getFields.asScala.map(f => f.getNumber).toSet @transient private lazy val deserializer = new ProtobufDeserializer(messageDescriptor, dataType) @@ -108,18 +112,18 @@ private[protobuf] case class ProtobufDataToCatalyst( val binary = input.asInstanceOf[Array[Byte]] try { result = DynamicMessage.parseFrom(messageDescriptor, binary) - val unknownFields = result.getUnknownFields - if (!unknownFields.asMap().isEmpty) { - unknownFields.asMap().keySet().asScala.map { number => - { - if (fieldsNumbers.contains(number)) { - return handleException( - new Throwable(s"Type mismatch encountered for field:" + - s" ${messageDescriptor.getFields.get(number)}")) - } - } - } + // If the Java class is available, it is likely more efficient to parse with it than using + // DynamicMessage. Can consider it in the future if parsing overhead is noticeable. + + result.getUnknownFields.asMap().keySet().asScala.find(fieldsNumbers.contains(_)) match { + case Some(number) => + // Unknown fields contain a field with same number as a known field. Must be due to + // mismatch of schema between writer and reader here. + throw new IllegalArgumentException(s"Type mismatch encountered for field:" + + s" ${messageDescriptor.getFields.get(number)}") + case None => } + val deserialized = deserializer.deserialize(result) assert( deserialized.isDefined, diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala index 283d1ca8c41..af30de40dad 100644 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala @@ -33,20 +33,21 @@ object functions { * * @param data * the binary column. 
- * @param descFilePath - * the protobuf descriptor in Message GeneratedMessageV3 format. * @param messageName * the protobuf message name to look for in descriptorFile. + * @param descFilePath + * the protobuf descriptor in Message GeneratedMessageV3 format. * @since 3.4.0 */ @Experimental def from_protobuf( data: Column, - descFilePath: String, messageName: String, + descFilePath: String, options: java.util.Map[String, String]): Column = { new Column( - ProtobufDataToCatalyst(data.expr, descFilePath, messageName, options.asScala.toMap)) + ProtobufDataToCatalyst(data.expr, messageName, Some(descFilePath), options.asScala.toMap) + ) } /** @@ -57,15 +58,34 @@ object functions { * * @param data * the binary column. - * @param descFilePath - * the protobuf descriptor in Message GeneratedMessageV3 format. * @param messageName * the protobuf MessageName to look for in descriptorFile. + * @param descFilePath + * the protobuf descriptor in Message GeneratedMessageV3 format. * @since 3.4.0 */ @Experimental - def from_protobuf(data: Column, descFilePath: String, messageName: String): Column = { - new Column(ProtobufDataToCatalyst(data.expr, descFilePath, messageName, Map.empty)) + def from_protobuf(data: Column, messageName: String, descFilePath: String): Column = { + new Column(ProtobufDataToCatalyst(data.expr, messageName, descFilePath = Some(descFilePath))) + // TODO: Add an option for user to provide descriptor file content as a buffer. This + // gives flexibility in how the content is fetched. + } + + /** + * Converts a binary column of Protobuf format into its corresponding catalyst value. The + * specified schema must match actual schema of the read data, otherwise the behavior is + * undefined: it may fail or return arbitrary result. To deserialize the data with a compatible + * and evolved schema, the expected Protobuf schema can be set via the option protoSchema. + * + * @param data + * the binary column. + * @param messageClassName + * The Protobuf class name. 
E.g. <code>org.spark.examples.protobuf.ExampleEvent</code>. + * @since 3.4.0 + */ + @Experimental + def from_protobuf(data: Column, messageClassName: String): Column = { + new Column(ProtobufDataToCatalyst(data.expr, messageClassName)) } /** @@ -73,14 +93,28 @@ object functions { * * @param data * the data column. - * @param descFilePath - * the protobuf descriptor in Message GeneratedMessageV3 format. * @param messageName * the protobuf MessageName to look for in descriptorFile. + * @param descFilePath + * the protobuf descriptor in Message GeneratedMessageV3 format. + * @since 3.4.0 + */ + @Experimental + def to_protobuf(data: Column, messageName: String, descFilePath: String): Column = { + new Column(CatalystDataToProtobuf(data.expr, messageName, Some(descFilePath))) + } + + /** + * Converts a column into binary of protobuf format. + * + * @param data + * the data column. + * @param messageClassName + * The Protobuf class name. E.g. <code>org.spark.examples.protobuf.ExampleEvent</code>. 
* @since 3.4.0 */ @Experimental - def to_protobuf(data: Column, descFilePath: String, messageName: String): Column = { - new Column(CatalystDataToProtobuf(data.expr, descFilePath, messageName)) + def to_protobuf(data: Column, messageClassName: String): Column = { + new Column(CatalystDataToProtobuf(data.expr, messageClassName)) } } diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala index 5ad043142a2..fa2ec9b7cd4 100644 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala @@ -22,13 +22,14 @@ import java.util.Locale import scala.collection.JavaConverters._ -import com.google.protobuf.{DescriptorProtos, Descriptors, InvalidProtocolBufferException} +import com.google.protobuf.{DescriptorProtos, Descriptors, InvalidProtocolBufferException, Message} import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor} import org.apache.spark.internal.Logging import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.protobuf.utils.SchemaConverters.IncompatibleSchemaException import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils private[sql] object ProtobufUtils extends Logging { @@ -132,23 +133,63 @@ private[sql] object ProtobufUtils extends Logging { } } - def buildDescriptor(descFilePath: String, messageName: String): Descriptor = { - val fileDescriptor: Descriptors.FileDescriptor = parseFileDescriptor(descFilePath) - var result: Descriptors.Descriptor = null; + /** + * Builds Protobuf message descriptor either from the Java class or from serialized descriptor + * read from the file. + * @param messageName + * Protobuf message name or Java class name. 
+ * @param descFilePathOpt + * When the file name is set, the descriptor and its dependencies are read from the file. + * Otherwise, the `messageName` is treated as a Java class name. + * @return + */ + def buildDescriptor(messageName: String, descFilePathOpt: Option[String]): Descriptor = { + descFilePathOpt match { + case Some(filePath) => buildDescriptor(descFilePath = filePath, messageName) + case None => buildDescriptorFromJavaClass(messageName) + } + } - for (descriptor <- fileDescriptor.getMessageTypes.asScala) { - if (descriptor.getName().equals(messageName)) { - result = descriptor - } + /** + * Loads the given protobuf class and returns Protobuf descriptor for it. + */ + def buildDescriptorFromJavaClass(protobufClassName: String): Descriptor = { + val protobufClass = try { + Utils.classForName(protobufClassName) + } catch { + case _: ClassNotFoundException => + val hasDots = protobufClassName.contains(".") + throw new IllegalArgumentException( + s"Could not load Protobuf class with name '$protobufClassName'" + + (if (hasDots) "" else ". Ensure the class name includes package prefix.") + ) + } + + if (!classOf[Message].isAssignableFrom(protobufClass)) { + throw new IllegalArgumentException(s"$protobufClassName is not a Protobuf message type") + // TODO: Need to support V2. This might work with V2 classes too. + } + + // Extract the descriptor from Protobuf message. 
+ protobufClass + .getDeclaredMethod("getDescriptor") + .invoke(null) + .asInstanceOf[Descriptor] + } + + def buildDescriptor(descFilePath: String, messageName: String): Descriptor = { + val descriptor = parseFileDescriptor(descFilePath).getMessageTypes.asScala.find { desc => + desc.getName == messageName || desc.getFullName == messageName } - if (null == result) { - throw new RuntimeException("Unable to locate Message '" + messageName + "' in Descriptor"); + descriptor match { + case Some(d) => d + case None => + throw new RuntimeException(s"Unable to locate Message '$messageName' in Descriptor") } - result } - def parseFileDescriptor(descFilePath: String): Descriptors.FileDescriptor = { + private def parseFileDescriptor(descFilePath: String): Descriptors.FileDescriptor = { var fileDescriptorSet: DescriptorProtos.FileDescriptorSet = null try { val dscFile = new BufferedInputStream(new FileInputStream(descFilePath)) diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala index e385b816abe..4fca06fb5d8 100644 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala @@ -66,6 +66,10 @@ object SchemaConverters { Some(DayTimeIntervalType.defaultConcreteType) case MESSAGE if fd.getMessageType.getName == "Timestamp" => Some(TimestampType) + // FIXME: Is the above accurate? Users can have protos named "Timestamp" but are not + // expected to be TimestampType in Spark. How about verifying fields? + // Same for "Duration". Only the Timestamp & Duration protos defined in + // google.protobuf package should default to corresponding Catalyst types. 
case MESSAGE if fd.isRepeated && fd.getMessageType.getOptions.hasMapEntry => var keyType: DataType = NullType var valueType: DataType = NullType diff --git a/connector/protobuf/src/test/resources/protobuf/catalyst_types.proto b/connector/protobuf/src/test/resources/protobuf/catalyst_types.proto index 54e6bc18df1..1deb193438c 100644 --- a/connector/protobuf/src/test/resources/protobuf/catalyst_types.proto +++ b/connector/protobuf/src/test/resources/protobuf/catalyst_types.proto @@ -19,9 +19,11 @@ syntax = "proto3"; -package org.apache.spark.sql.protobuf; +package org.apache.spark.sql.protobuf.protos; option java_outer_classname = "CatalystTypes"; +// TODO: import one or more protobuf files. + message BooleanMsg { bool bool_type = 1; } diff --git a/connector/protobuf/src/test/resources/protobuf/functions_suite.proto b/connector/protobuf/src/test/resources/protobuf/functions_suite.proto index f38c041b799..60f8c262141 100644 --- a/connector/protobuf/src/test/resources/protobuf/functions_suite.proto +++ b/connector/protobuf/src/test/resources/protobuf/functions_suite.proto @@ -20,7 +20,7 @@ syntax = "proto3"; -package org.apache.spark.sql.protobuf; +package org.apache.spark.sql.protobuf.protos; option java_outer_classname = "SimpleMessageProtos"; @@ -119,7 +119,7 @@ message SimpleMessageEnum { string key = 1; string value = 2; enum NestedEnum { - ESTED_NOTHING = 0; + ESTED_NOTHING = 0; // TODO: Fix the name. 
NESTED_FIRST = 1; NESTED_SECOND = 2; } diff --git a/connector/protobuf/src/test/resources/protobuf/serde_suite.proto b/connector/protobuf/src/test/resources/protobuf/serde_suite.proto index 1e3065259aa..a7459213a87 100644 --- a/connector/protobuf/src/test/resources/protobuf/serde_suite.proto +++ b/connector/protobuf/src/test/resources/protobuf/serde_suite.proto @@ -20,11 +20,11 @@ syntax = "proto3"; -package org.apache.spark.sql.protobuf; -option java_outer_classname = "SimpleMessageProtos"; +package org.apache.spark.sql.protobuf.protos; +option java_outer_classname = "SerdeSuiteProtos"; /* Clean Message*/ -message BasicMessage { +message SerdeBasicMessage { Foo foo = 1; } diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala index b730ebb4fea..19774a2ad07 100644 --- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala +++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, NoopFilters, OrderedFilters, StructFilters} import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, GenericInternalRow, Literal} import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData} +import org.apache.spark.sql.protobuf.protos.CatalystTypes.BytesMsg import org.apache.spark.sql.protobuf.utils.{ProtobufUtils, SchemaConverters} import org.apache.spark.sql.sources.{EqualTo, Not} import org.apache.spark.sql.test.SharedSparkSession @@ -35,18 +36,32 @@ class ProtobufCatalystDataConversionSuite with SharedSparkSession with ExpressionEvalHelper { - private def checkResult( + private val testFileDesc = 
testFile("protobuf/catalyst_types.desc").replace("file:/", "/") + private val javaClassNamePrefix = "org.apache.spark.sql.protobuf.protos.CatalystTypes$" + + private def checkResultWithEval( data: Literal, descFilePath: String, messageName: String, expected: Any): Unit = { - checkEvaluation( - ProtobufDataToCatalyst( - CatalystDataToProtobuf(data, descFilePath, messageName), - descFilePath, - messageName, - Map.empty), - prepareExpectedResult(expected)) + + withClue("(Eval check with Java class name)") { + val className = s"$javaClassNamePrefix$messageName" + checkEvaluation( + ProtobufDataToCatalyst( + CatalystDataToProtobuf(data, className), + className, + descFilePath = None), + prepareExpectedResult(expected)) + } + withClue("(Eval check with descriptor file)") { + checkEvaluation( + ProtobufDataToCatalyst( + CatalystDataToProtobuf(data, messageName, Some(descFilePath)), + messageName, + descFilePath = Some(descFilePath)), + prepareExpectedResult(expected)) + } } protected def checkUnsupportedRead( @@ -55,10 +70,11 @@ class ProtobufCatalystDataConversionSuite actualSchema: String, badSchema: String): Unit = { - val binary = CatalystDataToProtobuf(data, descFilePath, actualSchema) + val binary = CatalystDataToProtobuf(data, actualSchema, Some(descFilePath)) intercept[Exception] { - ProtobufDataToCatalyst(binary, descFilePath, badSchema, Map("mode" -> "FAILFAST")).eval() + ProtobufDataToCatalyst(binary, badSchema, Some(descFilePath), Map("mode" -> "FAILFAST")) + .eval() } val expected = { @@ -73,7 +89,7 @@ class ProtobufCatalystDataConversionSuite } checkEvaluation( - ProtobufDataToCatalyst(binary, descFilePath, badSchema, Map("mode" -> "PERMISSIVE")), + ProtobufDataToCatalyst(binary, badSchema, Some(descFilePath), Map("mode" -> "PERMISSIVE")), expected) } @@ -99,26 +115,32 @@ class ProtobufCatalystDataConversionSuite StructType(StructField("bytes_type", BinaryType, nullable = true) :: Nil), StructType(StructField("string_type", StringType, nullable = true) :: 
Nil)) - private val catalystTypesToProtoMessages: Map[DataType, String] = Map( - IntegerType -> "IntegerMsg", - DoubleType -> "DoubleMsg", - FloatType -> "FloatMsg", - BinaryType -> "BytesMsg", - StringType -> "StringMsg") + private val catalystTypesToProtoMessages: Map[DataType, (String, Any)] = Map( + IntegerType -> ("IntegerMsg", 0), + DoubleType -> ("DoubleMsg", 0.0d), + FloatType -> ("FloatMsg", 0.0f), + BinaryType -> ("BytesMsg", ByteString.empty().toByteArray), + StringType -> ("StringMsg", "")) testingTypes.foreach { dt => val seed = 1 + scala.util.Random.nextInt((1024 - 1) + 1) - val filePath = testFile("protobuf/catalyst_types.desc").replace("file:/", "/") test(s"single $dt with seed $seed") { + + val (messageName, defaultValue) = catalystTypesToProtoMessages(dt.fields(0).dataType) + val rand = new scala.util.Random(seed) - val data = RandomDataGenerator.forType(dt, rand = rand).get.apply() + val generator = RandomDataGenerator.forType(dt, rand = rand).get + var data = generator() + while (data.asInstanceOf[Row].get(0) == defaultValue) // Do not use default values, since + data = generator() // from_protobuf() returns null in v3. + val converter = CatalystTypeConverters.createToCatalystConverter(dt) val input = Literal.create(converter(data), dt) - checkResult( + checkResultWithEval( input, - filePath, - catalystTypesToProtoMessages(dt.fields(0).dataType), + testFileDesc, + messageName, input.eval()) } } @@ -137,6 +159,15 @@ class ProtobufCatalystDataConversionSuite val dynMsg = DynamicMessage.parseFrom(descriptor, data.toByteArray) val deserialized = deserializer.deserialize(dynMsg) + + // Verify Java class deserializer matches with descriptor based serializer. 
+ val javaDescriptor = ProtobufUtils + .buildDescriptorFromJavaClass(s"$javaClassNamePrefix$messageName") + assert(dataType == SchemaConverters.toSqlType(javaDescriptor).dataType) + val javaDeserialized = new ProtobufDeserializer(javaDescriptor, dataType, filters) + .deserialize(DynamicMessage.parseFrom(javaDescriptor, data.toByteArray)) + assert(deserialized == javaDeserialized) + expected match { case None => assert(deserialized.isEmpty) case Some(d) => @@ -145,7 +176,6 @@ class ProtobufCatalystDataConversionSuite } test("Handle unsupported input of message type") { - val testFileDesc = testFile("protobuf/catalyst_types.desc").replace("file:/", "/") val actualSchema = StructType( Seq( StructField("col_0", StringType, nullable = false), @@ -165,7 +195,6 @@ class ProtobufCatalystDataConversionSuite test("filter push-down to Protobuf deserializer") { - val testFileDesc = testFile("protobuf/catalyst_types.desc").replace("file:/", "/") val sqlSchema = new StructType() .add("name", "string") .add("age", "int") @@ -196,17 +225,23 @@ class ProtobufCatalystDataConversionSuite test("ProtobufDeserializer with binary type") { - val testFileDesc = testFile("protobuf/catalyst_types.desc").replace("file:/", "/") val bb = java.nio.ByteBuffer.wrap(Array[Byte](97, 48, 53)) - val descriptor = ProtobufUtils.buildDescriptor(testFileDesc, "BytesMsg") - - val dynamicMessage = DynamicMessage - .newBuilder(descriptor) - .setField(descriptor.findFieldByName("bytes_type"), ByteString.copyFrom(bb)) + val bytesProto = BytesMsg + .newBuilder() + .setBytesType(ByteString.copyFrom(bb)) .build() val expected = InternalRow(Array[Byte](97, 48, 53)) - checkDeserialization(testFileDesc, "BytesMsg", dynamicMessage, Some(expected)) + checkDeserialization(testFileDesc, "BytesMsg", bytesProto, Some(expected)) + } + + test("Full names for message using descriptor file") { + val withShortName = ProtobufUtils.buildDescriptor(testFileDesc, "BytesMsg") + assert(withShortName.findFieldByName("bytes_type") != 
null) + + val withFullName = ProtobufUtils.buildDescriptor( + testFileDesc, "org.apache.spark.sql.protobuf.BytesMsg") + assert(withFullName.findFieldByName("bytes_type") != null) } } diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala index 4e9bc1c1c28..72280fb0d9e 100644 --- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala +++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala @@ -23,8 +23,10 @@ import scala.collection.JavaConverters._ import com.google.protobuf.{ByteString, DynamicMessage} -import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.{Column, QueryTest, Row} import org.apache.spark.sql.functions.{lit, struct} +import org.apache.spark.sql.protobuf.protos.SimpleMessageProtos.SimpleMessageRepeated +import org.apache.spark.sql.protobuf.protos.SimpleMessageProtos.SimpleMessageRepeated.NestedEnum import org.apache.spark.sql.protobuf.utils.ProtobufUtils import org.apache.spark.sql.protobuf.utils.SchemaConverters.IncompatibleSchemaException import org.apache.spark.sql.test.SharedSparkSession @@ -35,6 +37,39 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri import testImplicits._ val testFileDesc = testFile("protobuf/functions_suite.desc").replace("file:/", "/") + private val javaClassNamePrefix = "org.apache.spark.sql.protobuf.protos.SimpleMessageProtos$" + + /** + * Runs the given closure twice. Once with descriptor file and second time with Java class name. 
+ */ + private def checkWithFileAndClassName(messageName: String)( + fn: (String, Option[String]) => Unit): Unit = { + withClue("(With descriptor file)") { + fn(messageName, Some(testFileDesc)) + } + withClue("(With Java class name)") { + fn(s"$javaClassNamePrefix$messageName", None) + } + } + + // A wrapper to invoke the right variant of from_protobuf() depending on arguments. + private def from_protobuf_wrapper( + col: Column, messageName: String, descFilePathOpt: Option[String]): Column = { + descFilePathOpt match { + case Some(descFilePath) => functions.from_protobuf(col, messageName, descFilePath) + case None => functions.from_protobuf(col, messageName) + } + } + + // A wrapper to invoke the right variant of to_protobuf() depending on arguments. + private def to_protobuf_wrapper( + col: Column, messageName: String, descFilePathOpt: Option[String]): Column = { + descFilePathOpt match { + case Some(descFilePath) => functions.to_protobuf(col, messageName, descFilePath) + case None => functions.to_protobuf(col, messageName) + } + } + test("roundtrip in to_protobuf and from_protobuf - struct") { val df = spark @@ -56,44 +91,45 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri lit(1202.00).cast(org.apache.spark.sql.types.FloatType).as("float_value"), lit(true).as("bool_value"), lit("0".getBytes).as("bytes_value")).as("SimpleMessage")) - val protoStructDF = df.select( - functions.to_protobuf($"SimpleMessage", testFileDesc, "SimpleMessage").as("proto")) - val actualDf = protoStructDF.select( - functions.from_protobuf($"proto", testFileDesc, "SimpleMessage").as("proto.*")) - checkAnswer(actualDf, df) + + checkWithFileAndClassName("SimpleMessage") { + case (name, descFilePathOpt) => + val protoStructDF = df.select( + to_protobuf_wrapper($"SimpleMessage", name, descFilePathOpt).as("proto")) + val actualDf = protoStructDF.select( + from_protobuf_wrapper($"proto", name, descFilePathOpt).as("proto.*")) + checkAnswer(actualDf, df) + } } 
test("roundtrip in from_protobuf and to_protobuf - Repeated") { - val descriptor = ProtobufUtils.buildDescriptor(testFileDesc, "SimpleMessageRepeated") - val dynamicMessage = DynamicMessage - .newBuilder(descriptor) - .setField(descriptor.findFieldByName("key"), "key") - .setField(descriptor.findFieldByName("value"), "value") - .addRepeatedField(descriptor.findFieldByName("rbool_value"), false) - .addRepeatedField(descriptor.findFieldByName("rbool_value"), true) - .addRepeatedField(descriptor.findFieldByName("rdouble_value"), 1092092.654d) - .addRepeatedField(descriptor.findFieldByName("rdouble_value"), 1092093.654d) - .addRepeatedField(descriptor.findFieldByName("rfloat_value"), 10903.0f) - .addRepeatedField(descriptor.findFieldByName("rfloat_value"), 10902.0f) - .addRepeatedField( - descriptor.findFieldByName("rnested_enum"), - descriptor.findEnumTypeByName("NestedEnum").findValueByName("ESTED_NOTHING")) - .addRepeatedField( - descriptor.findFieldByName("rnested_enum"), - descriptor.findEnumTypeByName("NestedEnum").findValueByName("NESTED_FIRST")) + val protoMessage = SimpleMessageRepeated + .newBuilder() + .setKey("key") + .setValue("value") + .addRboolValue(false) + .addRboolValue(true) + .addRdoubleValue(1092092.654d) + .addRdoubleValue(1092093.654d) + .addRfloatValue(10903.0f) + .addRfloatValue(10902.0f) + .addRnestedEnum(NestedEnum.ESTED_NOTHING) + .addRnestedEnum(NestedEnum.NESTED_FIRST) .build() - val df = Seq(dynamicMessage.toByteArray).toDF("value") - val fromProtoDF = df.select( - functions.from_protobuf($"value", testFileDesc, "SimpleMessageRepeated").as("value_from")) - val toProtoDF = fromProtoDF.select( - functions.to_protobuf($"value_from", testFileDesc, "SimpleMessageRepeated").as("value_to")) - val toFromProtoDF = toProtoDF.select( - functions - .from_protobuf($"value_to", testFileDesc, "SimpleMessageRepeated") - .as("value_to_from")) - checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + val df = 
Seq(protoMessage.toByteArray).toDF("value") + + checkWithFileAndClassName("SimpleMessageRepeated") { + case (name, descFilePathOpt) => + val fromProtoDF = df.select( + from_protobuf_wrapper($"value", name, descFilePathOpt).as("value_from")) + val toProtoDF = fromProtoDF.select( + to_protobuf_wrapper($"value_from", name, descFilePathOpt).as("value_to")) + val toFromProtoDF = toProtoDF.select( + from_protobuf_wrapper($"value_to", name, descFilePathOpt).as("value_to_from")) + checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + } } test("roundtrip in from_protobuf and to_protobuf - Repeated Message Once") { @@ -120,13 +156,17 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri .build() val df = Seq(dynamicMessage.toByteArray).toDF("value") - val fromProtoDF = df.select( - functions.from_protobuf($"value", testFileDesc, "RepeatedMessage").as("value_from")) - val toProtoDF = fromProtoDF.select( - functions.to_protobuf($"value_from", testFileDesc, "RepeatedMessage").as("value_to")) - val toFromProtoDF = toProtoDF.select( - functions.from_protobuf($"value_to", testFileDesc, "RepeatedMessage").as("value_to_from")) - checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + + checkWithFileAndClassName("RepeatedMessage") { + case (name, descFilePathOpt) => + val fromProtoDF = df.select( + from_protobuf_wrapper($"value", name, descFilePathOpt).as("value_from")) + val toProtoDF = fromProtoDF.select( + to_protobuf_wrapper($"value_from", name, descFilePathOpt).as("value_to")) + val toFromProtoDF = toProtoDF.select( + from_protobuf_wrapper($"value_to", name, descFilePathOpt).as("value_to_from")) + checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + } } test("roundtrip in from_protobuf and to_protobuf - Repeated Message Twice") { @@ -167,13 +207,17 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri 
.build() val df = Seq(dynamicMessage.toByteArray).toDF("value") - val fromProtoDF = df.select( - functions.from_protobuf($"value", testFileDesc, "RepeatedMessage").as("value_from")) - val toProtoDF = fromProtoDF.select( - functions.to_protobuf($"value_from", testFileDesc, "RepeatedMessage").as("value_to")) - val toFromProtoDF = toProtoDF.select( - functions.from_protobuf($"value_to", testFileDesc, "RepeatedMessage").as("value_to_from")) - checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + + checkWithFileAndClassName("RepeatedMessage") { + case (name, descFilePathOpt) => + val fromProtoDF = df.select( + from_protobuf_wrapper($"value", name, descFilePathOpt).as("value_from")) + val toProtoDF = fromProtoDF.select( + to_protobuf_wrapper($"value_from", name, descFilePathOpt).as("value_to")) + val toFromProtoDF = toProtoDF.select( + from_protobuf_wrapper($"value_to", name, descFilePathOpt).as("value_to_from")) + checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + } } test("roundtrip in from_protobuf and to_protobuf - Map") { @@ -257,13 +301,17 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri .build() val df = Seq(dynamicMessage.toByteArray).toDF("value") - val fromProtoDF = df.select( - functions.from_protobuf($"value", testFileDesc, "SimpleMessageMap").as("value_from")) - val toProtoDF = fromProtoDF.select( - functions.to_protobuf($"value_from", testFileDesc, "SimpleMessageMap").as("value_to")) - val toFromProtoDF = toProtoDF.select( - functions.from_protobuf($"value_to", testFileDesc, "SimpleMessageMap").as("value_to_from")) - checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + + checkWithFileAndClassName("SimpleMessageMap") { + case (name, descFilePathOpt) => + val fromProtoDF = df.select( + from_protobuf_wrapper($"value", name, descFilePathOpt).as("value_from")) + val toProtoDF = fromProtoDF.select( + 
to_protobuf_wrapper($"value_from", name, descFilePathOpt).as("value_to")) + val toFromProtoDF = toProtoDF.select( + from_protobuf_wrapper($"value_to", name, descFilePathOpt).as("value_to_from")) + checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + } } test("roundtrip in from_protobuf and to_protobuf - Enum") { @@ -289,13 +337,17 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri .build() val df = Seq(dynamicMessage.toByteArray).toDF("value") - val fromProtoDF = df.select( - functions.from_protobuf($"value", testFileDesc, "SimpleMessageEnum").as("value_from")) - val toProtoDF = fromProtoDF.select( - functions.to_protobuf($"value_from", testFileDesc, "SimpleMessageEnum").as("value_to")) - val toFromProtoDF = toProtoDF.select( - functions.from_protobuf($"value_to", testFileDesc, "SimpleMessageEnum").as("value_to_from")) - checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + + checkWithFileAndClassName("SimpleMessageEnum") { + case (name, descFilePathOpt) => + val fromProtoDF = df.select( + from_protobuf_wrapper($"value", name, descFilePathOpt).as("value_from")) + val toProtoDF = fromProtoDF.select( + to_protobuf_wrapper($"value_from", name, descFilePathOpt).as("value_to")) + val toFromProtoDF = toProtoDF.select( + from_protobuf_wrapper($"value_to", name, descFilePathOpt).as("value_to_from")) + checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + } } test("roundtrip in from_protobuf and to_protobuf - Multiple Message") { @@ -320,13 +372,17 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri .build() val df = Seq(dynamicMessage.toByteArray).toDF("value") - val fromProtoDF = df.select( - functions.from_protobuf($"value", testFileDesc, "MultipleExample").as("value_from")) - val toProtoDF = fromProtoDF.select( - functions.to_protobuf($"value_from", testFileDesc, 
"MultipleExample").as("value_to")) - val toFromProtoDF = toProtoDF.select( - functions.from_protobuf($"value_to", testFileDesc, "MultipleExample").as("value_to_from")) - checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + + checkWithFileAndClassName("MultipleExample") { + case (name, descFilePathOpt) => + val fromProtoDF = df.select( + from_protobuf_wrapper($"value", name, descFilePathOpt).as("value_from")) + val toProtoDF = fromProtoDF.select( + to_protobuf_wrapper($"value_from", name, descFilePathOpt).as("value_to")) + val toFromProtoDF = toProtoDF.select( + from_protobuf_wrapper($"value_to", name, descFilePathOpt).as("value_to_from")) + checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + } } test("Handle recursive fields in Protobuf schema, A->B->A") { @@ -352,15 +408,17 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri val df = Seq(messageB.toByteArray).toDF("messageB") - val e = intercept[IncompatibleSchemaException] { - df.select( - functions.from_protobuf($"messageB", testFileDesc, "recursiveB").as("messageFromProto")) - .show() + checkWithFileAndClassName("recursiveB") { + case (name, descFilePathOpt) => + val e = intercept[IncompatibleSchemaException] { + df.select( + from_protobuf_wrapper($"messageB", name, descFilePathOpt).as("messageFromProto")) + .show() + } + assert(e.getMessage.contains( + "Found recursive reference in Protobuf schema, which can not be processed by Spark:" + )) } - val expectedMessage = s""" - |Found recursive reference in Protobuf schema, which can not be processed by Spark: - |org.apache.spark.sql.protobuf.recursiveB.messageA""".stripMargin - assert(e.getMessage == expectedMessage) } test("Handle recursive fields in Protobuf schema, C->D->Array(C)") { @@ -386,16 +444,17 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri val df = Seq(messageD.toByteArray).toDF("messageD") - val e = 
intercept[IncompatibleSchemaException] { - df.select( - functions.from_protobuf($"messageD", testFileDesc, "recursiveD").as("messageFromProto")) - .show() + checkWithFileAndClassName("recursiveD") { + case (name, descFilePathOpt) => + val e = intercept[IncompatibleSchemaException] { + df.select( + from_protobuf_wrapper($"messageD", name, descFilePathOpt).as("messageFromProto")) + .show() + } + assert(e.getMessage.contains( + "Found recursive reference in Protobuf schema, which can not be processed by Spark:" + )) } - val expectedMessage = - s""" - |Found recursive reference in Protobuf schema, which can not be processed by Spark: - |org.apache.spark.sql.protobuf.recursiveD.messageC""".stripMargin - assert(e.getMessage == expectedMessage) } test("Handle extra fields : oldProducer -> newConsumer") { @@ -411,17 +470,17 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri val df = Seq(oldProducerMessage.toByteArray).toDF("oldProducerData") val fromProtoDf = df.select( functions - .from_protobuf($"oldProducerData", testFileDesc, "newConsumer") + .from_protobuf($"oldProducerData", "newConsumer", testFileDesc) .as("fromProto")) val toProtoDf = fromProtoDf.select( functions - .to_protobuf($"fromProto", testFileDesc, "newConsumer") + .to_protobuf($"fromProto", "newConsumer", testFileDesc) .as("toProto")) val toProtoDfToFromProtoDf = toProtoDf.select( functions - .from_protobuf($"toProto", testFileDesc, "newConsumer") + .from_protobuf($"toProto", "newConsumer", testFileDesc) .as("toProtoToFromProto")) val actualFieldNames = @@ -452,7 +511,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri val df = Seq(newProducerMessage.toByteArray).toDF("newProducerData") val fromProtoDf = df.select( functions - .from_protobuf($"newProducerData", testFileDesc, "oldConsumer") + .from_protobuf($"newProducerData", "oldConsumer", testFileDesc) .as("oldConsumerProto")) val expectedFieldNames = oldConsumer.getFields.asScala.map(f 
=> f.getName) @@ -481,8 +540,9 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri )), schema ) + val toProtobuf = inputDf.select( - functions.to_protobuf($"requiredMsg", testFileDesc, "requiredMsg") + functions.to_protobuf($"requiredMsg", "requiredMsg", testFileDesc) .as("to_proto")) val binary = toProtobuf.take(1).toSeq(0).get(0).asInstanceOf[Array[Byte]] @@ -498,7 +558,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri assert(actualMessage.getField(messageDescriptor.findFieldByName("col_3")) == 0) val fromProtoDf = toProtobuf.select( - functions.from_protobuf($"to_proto", testFileDesc, "requiredMsg") as 'from_proto) + functions.from_protobuf($"to_proto", "requiredMsg", testFileDesc) as 'from_proto) assert(fromProtoDf.select("from_proto.key").take(1).toSeq(0).get(0) == inputDf.select("requiredMsg.key").take(1).toSeq(0).get(0)) @@ -526,16 +586,20 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri .build() val df = Seq(basicMessage.toByteArray).toDF("value") - val resultFrom = df - .select(functions.from_protobuf($"value", testFileDesc, "BasicMessage") as 'sample) - .where("sample.string_value == \"slam\"") - val resultToFrom = resultFrom - .select(functions.to_protobuf($"sample", testFileDesc, "BasicMessage") as 'value) - .select(functions.from_protobuf($"value", testFileDesc, "BasicMessage") as 'sample) - .where("sample.string_value == \"slam\"") + checkWithFileAndClassName("BasicMessage") { + case (name, descFilePathOpt) => + val resultFrom = df + .select(from_protobuf_wrapper($"value", name, descFilePathOpt) as 'sample) + .where("sample.string_value == \"slam\"") + + val resultToFrom = resultFrom + .select(to_protobuf_wrapper($"sample", name, descFilePathOpt) as 'value) + .select(from_protobuf_wrapper($"value", name, descFilePathOpt) as 'sample) + .where("sample.string_value == \"slam\"") - assert(resultFrom.except(resultToFrom).isEmpty) + 
assert(resultFrom.except(resultToFrom).isEmpty) + } } test("Handle TimestampType between to_protobuf and from_protobuf") { @@ -556,22 +620,24 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri schema ) - val toProtoDf = inputDf - .select(functions.to_protobuf($"timeStampMsg", testFileDesc, "timeStampMsg") as 'to_proto) + checkWithFileAndClassName("timeStampMsg") { + case (name, descFilePathOpt) => + val toProtoDf = inputDf + .select(to_protobuf_wrapper($"timeStampMsg", name, descFilePathOpt) as 'to_proto) - val fromProtoDf = toProtoDf - .select(functions.from_protobuf($"to_proto", testFileDesc, "timeStampMsg") as 'timeStampMsg) - fromProtoDf.show(truncate = false) + val fromProtoDf = toProtoDf + .select(from_protobuf_wrapper($"to_proto", name, descFilePathOpt) as 'timeStampMsg) - val actualFields = fromProtoDf.schema.fields.toList - val expectedFields = inputDf.schema.fields.toList + val actualFields = fromProtoDf.schema.fields.toList + val expectedFields = inputDf.schema.fields.toList - assert(actualFields.size === expectedFields.size) - assert(actualFields === expectedFields) - assert(fromProtoDf.select("timeStampMsg.key").take(1).toSeq(0).get(0) - === inputDf.select("timeStampMsg.key").take(1).toSeq(0).get(0)) - assert(fromProtoDf.select("timeStampMsg.stmp").take(1).toSeq(0).get(0) - === inputDf.select("timeStampMsg.stmp").take(1).toSeq(0).get(0)) + assert(actualFields.size === expectedFields.size) + assert(actualFields === expectedFields) + assert(fromProtoDf.select("timeStampMsg.key").take(1).toSeq(0).get(0) + === inputDf.select("timeStampMsg.key").take(1).toSeq(0).get(0)) + assert(fromProtoDf.select("timeStampMsg.stmp").take(1).toSeq(0).get(0) + === inputDf.select("timeStampMsg.stmp").take(1).toSeq(0).get(0)) + } } test("Handle DayTimeIntervalType between to_protobuf and from_protobuf") { @@ -595,21 +661,23 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri schema ) - val toProtoDf = inputDf - 
.select(functions.to_protobuf($"durationMsg", testFileDesc, "durationMsg") as 'to_proto) + checkWithFileAndClassName("durationMsg") { + case (name, descFilePathOpt) => + val toProtoDf = inputDf + .select(to_protobuf_wrapper($"durationMsg", name, descFilePathOpt) as 'to_proto) - val fromProtoDf = toProtoDf - .select(functions.from_protobuf($"to_proto", testFileDesc, "durationMsg") as 'durationMsg) + val fromProtoDf = toProtoDf + .select(from_protobuf_wrapper($"to_proto", name, descFilePathOpt) as 'durationMsg) - val actualFields = fromProtoDf.schema.fields.toList - val expectedFields = inputDf.schema.fields.toList - - assert(actualFields.size === expectedFields.size) - assert(actualFields === expectedFields) - assert(fromProtoDf.select("durationMsg.key").take(1).toSeq(0).get(0) - === inputDf.select("durationMsg.key").take(1).toSeq(0).get(0)) - assert(fromProtoDf.select("durationMsg.duration").take(1).toSeq(0).get(0) - === inputDf.select("durationMsg.duration").take(1).toSeq(0).get(0)) + val actualFields = fromProtoDf.schema.fields.toList + val expectedFields = inputDf.schema.fields.toList + assert(actualFields.size === expectedFields.size) + assert(actualFields === expectedFields) + assert(fromProtoDf.select("durationMsg.key").take(1).toSeq(0).get(0) + === inputDf.select("durationMsg.key").take(1).toSeq(0).get(0)) + assert(fromProtoDf.select("durationMsg.duration").take(1).toSeq(0).get(0) + === inputDf.select("durationMsg.duration").take(1).toSeq(0).get(0)) + } } } diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala index 37c59743e77..efc02524e68 100644 --- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala +++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala @@ -36,6 +36,7 @@ class ProtobufSerdeSuite extends SharedSparkSession { import 
ProtoSerdeSuite.MatchType._ val testFileDesc = testFile("protobuf/serde_suite.desc").replace("file:/", "/") + private val javaClassNamePrefix = "org.apache.spark.sql.protobuf.protos.SerdeSuiteProtos$" test("Test basic conversion") { withFieldMatchType { fieldMatch => @@ -96,7 +97,9 @@ class ProtobufSerdeSuite extends SharedSparkSession { } test("Fail to convert with deeply nested field type mismatch") { - val protoFile = ProtobufUtils.buildDescriptor(testFileDesc, "MissMatchTypeInDeepNested") + val protoFile = ProtobufUtils.buildDescriptorFromJavaClass( + s"${javaClassNamePrefix}MissMatchTypeInDeepNested" + ) val catalyst = new StructType().add("top", CATALYST_STRUCT) withFieldMatchType { fieldMatch => @@ -105,8 +108,8 @@ class ProtobufSerdeSuite extends SharedSparkSession { Deserializer, fieldMatch, s"Cannot convert Protobuf field 'top.foo.bar' to SQL field 'top.foo.bar' because schema " + - s"is incompatible (protoType = org.apache.spark.sql.protobuf.TypeMiss.bar " + - s"LABEL_OPTIONAL LONG INT64, sqlType = INT)".stripMargin, + s"is incompatible (protoType = org.apache.spark.sql.protobuf.protos.TypeMiss.bar " + + s"LABEL_OPTIONAL LONG INT64, sqlType = INT)", catalyst) assertFailedConversionMessage( diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index e5a48080e83..cc103e4ab00 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -716,8 +716,10 @@ object SparkProtobuf { dependencyOverrides += "com.google.protobuf" % "protobuf-java" % protoVersion, - (Compile / PB.targets) := Seq( - PB.gens.java -> (Compile / sourceManaged).value, + (Test / PB.protoSources) += (Test / sourceDirectory).value / "resources", + + (Test / PB.targets) := Seq( + PB.gens.java -> target.value / "generated-test-sources" ), (assembly / test) := { }, diff --git a/python/pyspark/sql/protobuf/functions.py b/python/pyspark/sql/protobuf/functions.py index 9f8b90095df..2059d868c7c 100644 --- a/python/pyspark/sql/protobuf/functions.py +++ 
b/python/pyspark/sql/protobuf/functions.py @@ -31,8 +31,8 @@ if TYPE_CHECKING: def from_protobuf( data: "ColumnOrName", - descFilePath: str, messageName: str, + descFilePath: str, options: Optional[Dict[str, str]] = None, ) -> Column: """ @@ -48,10 +48,10 @@ def from_protobuf( ---------- data : :class:`~pyspark.sql.Column` or str the binary column. - descFilePath : str - the protobuf descriptor in Message GeneratedMessageV3 format. messageName: str the protobuf message name to look for in descriptor file. + descFilePath : str + the protobuf descriptor in Message GeneratedMessageV3 format. options : dict, optional options to control how the protobuf record is parsed. @@ -80,10 +80,10 @@ def from_protobuf( ... f.flush() ... message_name = 'SimpleMessage' ... proto_df = df.select( - ... to_protobuf(df.value, desc_file_path, message_name).alias("value")) + ... to_protobuf(df.value, message_name, desc_file_path).alias("value")) ... proto_df.show(truncate=False) ... proto_df = proto_df.select( - ... from_protobuf(proto_df.value, desc_file_path, message_name).alias("value")) + ... from_protobuf(proto_df.value, message_name, desc_file_path).alias("value")) ... proto_df.show(truncate=False) +----------------------------------------+ |value | @@ -101,7 +101,7 @@ def from_protobuf( assert sc is not None and sc._jvm is not None try: jc = sc._jvm.org.apache.spark.sql.protobuf.functions.from_protobuf( - _to_java_column(data), descFilePath, messageName, options or {} + _to_java_column(data), messageName, descFilePath, options or {} ) except TypeError as e: if str(e) == "'JavaPackage' object is not callable": @@ -110,7 +110,7 @@ def from_protobuf( return Column(jc) -def to_protobuf(data: "ColumnOrName", descFilePath: str, messageName: str) -> Column: +def to_protobuf(data: "ColumnOrName", messageName: str, descFilePath: str) -> Column: """ Converts a column into binary of protobuf format. 
@@ -120,10 +120,10 @@ def to_protobuf(data: "ColumnOrName", descFilePath: str, messageName: str) -> Co ---------- data : :class:`~pyspark.sql.Column` or str the data column. - descFilePath : str - the protobuf descriptor in Message GeneratedMessageV3 format. messageName: str the protobuf message name to look for in descriptor file. + descFilePath : str + the protobuf descriptor in Message GeneratedMessageV3 format. Notes ----- @@ -150,7 +150,7 @@ def to_protobuf(data: "ColumnOrName", descFilePath: str, messageName: str) -> Co ... f.flush() ... message_name = 'SimpleMessage' ... proto_df = df.select( - ... to_protobuf(df.value, desc_file_path, message_name).alias("suite")) + ... to_protobuf(df.value, message_name, desc_file_path).alias("suite")) ... proto_df.show(truncate=False) +-------------------------------------------+ |suite | @@ -162,7 +162,7 @@ def to_protobuf(data: "ColumnOrName", descFilePath: str, messageName: str) -> Co assert sc is not None and sc._jvm is not None try: jc = sc._jvm.org.apache.spark.sql.protobuf.functions.to_protobuf( - _to_java_column(data), descFilePath, messageName + _to_java_column(data), messageName, descFilePath ) except TypeError as e: if str(e) == "'JavaPackage' object is not callable": --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org