parthchandra commented on code in PR #395:
URL: https://github.com/apache/datafusion-comet/pull/395#discussion_r1607363360


##########
common/src/main/java/org/apache/comet/parquet/CometParquetToSparkSchemaConverter.scala:
##########
@@ -0,0 +1,403 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.parquet
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.parquet.io.{ColumnIO, GroupColumnIO, PrimitiveColumnIO}
+import org.apache.parquet.schema._
+import org.apache.parquet.schema.LogicalTypeAnnotation._
+import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._
+import org.apache.parquet.schema.Type.Repetition._
+import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.normalizeFieldName
+import org.apache.spark.sql.execution.datasources.parquet.{ParquetColumn, ParquetToSparkSchemaConverter}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+
+class CometParquetToSparkSchemaConverter(
+    assumeBinaryIsString: Boolean = SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get,
+    assumeInt96IsTimestamp: Boolean = SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get,
+    caseSensitive: Boolean = SQLConf.CASE_SENSITIVE.defaultValue.get,
+    inferTimestampNTZ: Boolean = SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED.defaultValue.get,
+    nanosAsLong: Boolean = SQLConf.LEGACY_PARQUET_NANOS_AS_LONG.defaultValue.get)
+  extends ParquetToSparkSchemaConverter {
+
+  def this(conf: Configuration) = this(
+    assumeBinaryIsString = conf.get(SQLConf.PARQUET_BINARY_AS_STRING.key).toBoolean,
+    assumeInt96IsTimestamp = conf.get(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key).toBoolean,
+    caseSensitive = conf.get(SQLConf.CASE_SENSITIVE.key).toBoolean,
+    inferTimestampNTZ = conf.get(SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED.key).toBoolean,
+    nanosAsLong = conf.get(SQLConf.LEGACY_PARQUET_NANOS_AS_LONG.key).toBoolean)
+
+  override def convertField(
+      field: ColumnIO,
+      sparkReadType: Option[DataType] = None): ParquetColumn = {
+    val targetType = sparkReadType.map {
+      case udt: UserDefinedType[_] => udt.sqlType
+      case otherType => otherType
+    }
+    field match {
+      case primitiveColumn: PrimitiveColumnIO => convertPrimitiveField(primitiveColumn, targetType)
+      case groupColumn: GroupColumnIO => convertGroupField(groupColumn, targetType)
+    }
+  }
+
+  private def convertPrimitiveField(
+      primitiveColumn: PrimitiveColumnIO,
+      sparkReadType: Option[DataType] = None): ParquetColumn = {
+    val parquetType = primitiveColumn.getType.asPrimitiveType()
+    val typeAnnotation = primitiveColumn.getType.getLogicalTypeAnnotation
+    val typeName = primitiveColumn.getPrimitive
+
+    def typeString =
+      if (typeAnnotation == null) s"$typeName" else s"$typeName ($typeAnnotation)"
+
+    def typeNotImplemented() =
+      throw new UnsupportedOperationException("unsupported Parquet type: " + 
typeString)
+
+    def illegalType() =
+      throw new UnsupportedOperationException("Illegal Parquet type: " + 
typeString)
+
+    // When maxPrecision = -1, we skip precision range check, and always respect the precision
+    // specified in field.getDecimalMetadata. This is useful when interpreting decimal types stored
+    // as binaries with variable lengths.
+    def makeDecimalType(maxPrecision: Int = -1): DecimalType = {
+      val decimalLogicalTypeAnnotation = typeAnnotation
+        .asInstanceOf[DecimalLogicalTypeAnnotation]
+      val precision = decimalLogicalTypeAnnotation.getPrecision
+      val scale = decimalLogicalTypeAnnotation.getScale
+
+      CometParquetSchemaConverter.checkConversionRequirement(
+        maxPrecision == -1 || 1 <= precision && precision <= maxPrecision,
+        s"Invalid decimal precision: $typeName cannot store $precision digits 
(max $maxPrecision)")
+
+      DecimalType(precision, scale)
+    }
+
+    val sparkType = sparkReadType.getOrElse(typeName match {
+      case BOOLEAN => BooleanType
+
+      case FLOAT => FloatType
+
+      case DOUBLE => DoubleType
+
+      case INT32 =>
+        typeAnnotation match {
+          case intTypeAnnotation: IntLogicalTypeAnnotation if intTypeAnnotation.isSigned =>
+            intTypeAnnotation.getBitWidth match {
+              case 8 => ByteType
+              case 16 => ShortType
+              case 32 => IntegerType
+              case _ => illegalType()
+            }
+          case null => IntegerType
+          case _: DateLogicalTypeAnnotation => DateType
+          case _: DecimalLogicalTypeAnnotation => makeDecimalType(Decimal.MAX_INT_DIGITS)
+          case intTypeAnnotation: IntLogicalTypeAnnotation if !intTypeAnnotation.isSigned =>
+            intTypeAnnotation.getBitWidth match {
+              case 8 => ShortType
+              case 16 => IntegerType
+              case 32 => LongType
+              case _ => illegalType()
+            }
+          case t: TimestampLogicalTypeAnnotation if t.getUnit == TimeUnit.MILLIS =>
+            typeNotImplemented()
+          case _ => illegalType()
+        }
+
+      case INT64 =>
+        typeAnnotation match {
+          case intTypeAnnotation: IntLogicalTypeAnnotation if intTypeAnnotation.isSigned =>
+            intTypeAnnotation.getBitWidth match {
+              case 64 => LongType
+              case _ => illegalType()
+            }
+          case null => LongType
+          case _: DecimalLogicalTypeAnnotation => makeDecimalType(Decimal.MAX_LONG_DIGITS)
+          case intTypeAnnotation: IntLogicalTypeAnnotation if !intTypeAnnotation.isSigned =>
+            intTypeAnnotation.getBitWidth match {
+              // The precision to hold the largest unsigned long is:
+              // `java.lang.Long.toUnsignedString(-1).length` = 20
+              case 64 => DecimalType(20, 0)
+              case _ => illegalType()
+            }
+          case timestamp: TimestampLogicalTypeAnnotation
+            if timestamp.getUnit == TimeUnit.MICROS || timestamp.getUnit == TimeUnit.MILLIS =>
+            if (timestamp.isAdjustedToUTC || !inferTimestampNTZ) {
+              TimestampType
+            } else {
+              TimestampNTZType
+            }
+          // SPARK-40819: NANOS are not supported as a Timestamp, convert to LongType without
+          // timezone awareness to address behaviour regression introduced by SPARK-34661
+          case timestamp: TimestampLogicalTypeAnnotation
+            if timestamp.getUnit == TimeUnit.NANOS && nanosAsLong =>
+            LongType
+          case _ => illegalType()
+        }
+
+      case INT96 =>
+        CometParquetSchemaConverter.checkConversionRequirement(
+          assumeInt96IsTimestamp,
+          "INT96 is not supported unless it's interpreted as timestamp. " +
+            s"Please try to set ${SQLConf.PARQUET_INT96_AS_TIMESTAMP.key} to true.")
+        TimestampType
+
+      case BINARY =>
+        typeAnnotation match {
+          case _: StringLogicalTypeAnnotation | _: EnumLogicalTypeAnnotation |
+               _: JsonLogicalTypeAnnotation => StringType
+          case null if assumeBinaryIsString => StringType
+          case null => BinaryType
+          case _: BsonLogicalTypeAnnotation => BinaryType
+          case _: DecimalLogicalTypeAnnotation => makeDecimalType()
+          case _ => illegalType()
+        }
+
+      case FIXED_LEN_BYTE_ARRAY =>
+        typeAnnotation match {
+          case _: DecimalLogicalTypeAnnotation =>
+            makeDecimalType(Decimal.maxPrecisionForBytes(parquetType.getTypeLength))
+          case _: UUIDLogicalTypeAnnotation => StringType

Review Comment:
   Are we adding this just for testing, or is it likely to be useful in other places?
   Also, instead of copying, could we not just extend the Spark class and override `convertField`? We could then call our own implementation of `convertPrimitiveField` for UUID and let the parent implementation handle the rest.
   We are likely to miss changes made in Spark if we keep a copy.
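   For illustration, a rough, untested sketch of that approach might look like the snippet below. It assumes `convertField` stays overridable (which the copy already relies on) and that Spark's `ParquetColumn(DataType, PrimitiveColumnIO)` companion overload is accessible from Comet:

   ```scala
   import org.apache.hadoop.conf.Configuration
   import org.apache.parquet.io.{ColumnIO, PrimitiveColumnIO}
   import org.apache.parquet.schema.LogicalTypeAnnotation.UUIDLogicalTypeAnnotation
   import org.apache.spark.sql.execution.datasources.parquet.{ParquetColumn, ParquetToSparkSchemaConverter}
   import org.apache.spark.sql.types.{DataType, StringType}

   // Intercept only the UUID case; defer every other field to Spark's converter so that
   // upstream changes are picked up automatically.
   class CometParquetToSparkSchemaConverter(conf: Configuration)
     extends ParquetToSparkSchemaConverter(conf) {

     override def convertField(
         field: ColumnIO,
         sparkReadType: Option[DataType] = None): ParquetColumn = {
       field match {
         case primitive: PrimitiveColumnIO
             if primitive.getType.getLogicalTypeAnnotation.isInstanceOf[UUIDLogicalTypeAnnotation] =>
           // Comet-specific mapping: read Parquet UUID (FIXED_LEN_BYTE_ARRAY) columns as strings.
           // Assumption: ParquetColumn's apply(DataType, PrimitiveColumnIO) overload is usable here.
           ParquetColumn(sparkReadType.getOrElse(StringType), primitive)
         case other =>
           // Everything else is handled by the parent implementation.
           super.convertField(other, sparkReadType)
       }
     }
   }
   ```
   That would keep only the UUID mapping in Comet instead of re-copying the rest of the Spark converter.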



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.


