parthchandra commented on code in PR #395: URL: https://github.com/apache/datafusion-comet/pull/395#discussion_r1607363360
########## common/src/main/java/org/apache/comet/parquet/CometParquetToSparkSchemaConverter.scala: ########## @@ -0,0 +1,403 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.parquet + +import org.apache.hadoop.conf.Configuration +import org.apache.parquet.io.{ColumnIO, GroupColumnIO, PrimitiveColumnIO} +import org.apache.parquet.schema._ +import org.apache.parquet.schema.LogicalTypeAnnotation._ +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ +import org.apache.parquet.schema.Type.Repetition._ +import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.normalizeFieldName +import org.apache.spark.sql.execution.datasources.parquet.{ParquetColumn, ParquetToSparkSchemaConverter} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +class CometParquetToSparkSchemaConverter( + assumeBinaryIsString: Boolean = SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get, + assumeInt96IsTimestamp: Boolean = SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get, + caseSensitive: Boolean = SQLConf.CASE_SENSITIVE.defaultValue.get, + inferTimestampNTZ: Boolean = SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED.defaultValue.get, + nanosAsLong: Boolean = 
SQLConf.LEGACY_PARQUET_NANOS_AS_LONG.defaultValue.get) extends ParquetToSparkSchemaConverter { + + def this(conf: Configuration) = this( + assumeBinaryIsString = conf.get(SQLConf.PARQUET_BINARY_AS_STRING.key).toBoolean, + assumeInt96IsTimestamp = conf.get(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key).toBoolean, + caseSensitive = conf.get(SQLConf.CASE_SENSITIVE.key).toBoolean, + inferTimestampNTZ = conf.get(SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED.key).toBoolean, + nanosAsLong = conf.get(SQLConf.LEGACY_PARQUET_NANOS_AS_LONG.key).toBoolean) + + override def convertField( + field: ColumnIO, + sparkReadType: Option[DataType] = None): ParquetColumn = { + val targetType = sparkReadType.map { + case udt: UserDefinedType[_] => udt.sqlType + case otherType => otherType + } + field match { + case primitiveColumn: PrimitiveColumnIO => convertPrimitiveField(primitiveColumn, targetType) + case groupColumn: GroupColumnIO => convertGroupField(groupColumn, targetType) + } + } + + private def convertPrimitiveField( + primitiveColumn: PrimitiveColumnIO, + sparkReadType: Option[DataType] = None): ParquetColumn = { + val parquetType = primitiveColumn.getType.asPrimitiveType() + val typeAnnotation = primitiveColumn.getType.getLogicalTypeAnnotation + val typeName = primitiveColumn.getPrimitive + + def typeString = + if (typeAnnotation == null) s"$typeName" else s"$typeName ($typeAnnotation)" + + def typeNotImplemented() = + throw new UnsupportedOperationException("unsupported Parquet type: " + typeString) + + def illegalType() = + throw new UnsupportedOperationException("Illegal Parquet type: " + typeString) + + // When maxPrecision = -1, we skip precision range check, and always respect the precision + // specified in field.getDecimalMetadata. This is useful when interpreting decimal types stored + // as binaries with variable lengths. 
+ def makeDecimalType(maxPrecision: Int = -1): DecimalType = { + val decimalLogicalTypeAnnotation = typeAnnotation + .asInstanceOf[DecimalLogicalTypeAnnotation] + val precision = decimalLogicalTypeAnnotation.getPrecision + val scale = decimalLogicalTypeAnnotation.getScale + + CometParquetSchemaConverter.checkConversionRequirement( + maxPrecision == -1 || 1 <= precision && precision <= maxPrecision, + s"Invalid decimal precision: $typeName cannot store $precision digits (max $maxPrecision)") + + DecimalType(precision, scale) + } + + val sparkType = sparkReadType.getOrElse(typeName match { + case BOOLEAN => BooleanType + + case FLOAT => FloatType + + case DOUBLE => DoubleType + + case INT32 => + typeAnnotation match { + case intTypeAnnotation: IntLogicalTypeAnnotation if intTypeAnnotation.isSigned => + intTypeAnnotation.getBitWidth match { + case 8 => ByteType + case 16 => ShortType + case 32 => IntegerType + case _ => illegalType() + } + case null => IntegerType + case _: DateLogicalTypeAnnotation => DateType + case _: DecimalLogicalTypeAnnotation => makeDecimalType(Decimal.MAX_INT_DIGITS) + case intTypeAnnotation: IntLogicalTypeAnnotation if !intTypeAnnotation.isSigned => + intTypeAnnotation.getBitWidth match { + case 8 => ShortType + case 16 => IntegerType + case 32 => LongType + case _ => illegalType() + } + case t: TimestampLogicalTypeAnnotation if t.getUnit == TimeUnit.MILLIS => + typeNotImplemented() + case _ => illegalType() + } + + case INT64 => + typeAnnotation match { + case intTypeAnnotation: IntLogicalTypeAnnotation if intTypeAnnotation.isSigned => + intTypeAnnotation.getBitWidth match { + case 64 => LongType + case _ => illegalType() + } + case null => LongType + case _: DecimalLogicalTypeAnnotation => makeDecimalType(Decimal.MAX_LONG_DIGITS) + case intTypeAnnotation: IntLogicalTypeAnnotation if !intTypeAnnotation.isSigned => + intTypeAnnotation.getBitWidth match { + // The precision to hold the largest unsigned long is: + // 
`java.lang.Long.toUnsignedString(-1).length` = 20 + case 64 => DecimalType(20, 0) + case _ => illegalType() + } + case timestamp: TimestampLogicalTypeAnnotation + if timestamp.getUnit == TimeUnit.MICROS || timestamp.getUnit == TimeUnit.MILLIS => + val inferTimestampNTZ = SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED.defaultValue.get + if (timestamp.isAdjustedToUTC || !inferTimestampNTZ) { + TimestampType + } else { + TimestampNTZType + } + // SPARK-40819: NANOS are not supported as a Timestamp, convert to LongType without + // timezone awareness to address behaviour regression introduced by SPARK-34661 + case timestamp: TimestampLogicalTypeAnnotation + if timestamp.getUnit == TimeUnit.NANOS && SQLConf.LEGACY_PARQUET_NANOS_AS_LONG.defaultValue.get => + LongType + case _ => illegalType() + } + + case INT96 => + CometParquetSchemaConverter.checkConversionRequirement( + SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get, + "INT96 is not supported unless it's interpreted as timestamp. " + + s"Please try to set ${SQLConf.PARQUET_INT96_AS_TIMESTAMP.key} to true.") + TimestampType + + case BINARY => + typeAnnotation match { + case _: StringLogicalTypeAnnotation | _: EnumLogicalTypeAnnotation | + _: JsonLogicalTypeAnnotation => StringType + case null if SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get => StringType + case null => BinaryType + case _: BsonLogicalTypeAnnotation => BinaryType + case _: DecimalLogicalTypeAnnotation => makeDecimalType() + case _ => illegalType() + } + + case FIXED_LEN_BYTE_ARRAY => + typeAnnotation match { + case _: DecimalLogicalTypeAnnotation => + makeDecimalType(Decimal.maxPrecisionForBytes(parquetType.getTypeLength)) + case _: UUIDLogicalTypeAnnotation => StringType Review Comment: Are we adding this just to test? Or is this likely to be useful in other places? Also, instead of copying, could we not just extend the Spark class and override `convertField`? 
We can then call our own implementation of `convertPrimitiveField` for UUID and let the parent implementation handle the rest? We are likely to miss changes made in Spark if we make a copy. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For additional commands, e-mail: github-h...@datafusion.apache.org