huaxingao commented on a change in pull request #32049:
URL: https://github.com/apache/spark/pull/32049#discussion_r607421850
##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
##########
@@ -127,4 +147,360 @@ object ParquetUtils {
     file.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE ||
       file.getName == ParquetFileWriter.PARQUET_METADATA_FILE
   }
+
+  /**
+   * When the Aggregates (Max/Min/Count) are pushed down to parquet, we don't need to
+   * createRowBaseReader to read data from parquet and aggregate at spark layer. Instead we want
+   * to calculate the Aggregates (Max/Min/Count) result using the statistics information
+   * from parquet footer file, and then construct an InternalRow from these Aggregate results.
+   *
+   * @return Aggregate results in the format of InternalRow
+   */
+  private[sql] def aggResultToSparkInternalRows(
+      footer: ParquetMetadata,
+      parquetTypes: Seq[PrimitiveType.PrimitiveTypeName],
+      values: Seq[Any],
+      dataSchema: StructType,
+      datetimeRebaseModeInRead: String,
+      int96RebaseModeInRead: String,
+      convertTz: Option[ZoneId]): InternalRow = {
+    val mutableRow = new SpecificInternalRow(dataSchema.fields.map(x => x.dataType))
+    val footerFileMetaData = footer.getFileMetaData
+    val datetimeRebaseMode = DataSourceUtils.datetimeRebaseMode(
+      footerFileMetaData.getKeyValueMetaData.get,
+      datetimeRebaseModeInRead)
+    val int96RebaseMode = DataSourceUtils.int96RebaseMode(
+      footerFileMetaData.getKeyValueMetaData.get,
+      int96RebaseModeInRead)
+    parquetTypes.zipWithIndex.map {
+      case (PrimitiveType.PrimitiveTypeName.INT32, i) =>
+        if (values(i) == null) {
+          mutableRow.setNullAt(i)
+        } else {
+          dataSchema.fields(i).dataType match {
+            case b: ByteType =>
+              mutableRow.setByte(i, values(i).asInstanceOf[Integer].toByte)
+            case s: ShortType =>
+              mutableRow.setShort(i, values(i).asInstanceOf[Integer].toShort)
+            case int: IntegerType =>
+              mutableRow.setInt(i, values(i).asInstanceOf[Integer])
+            case d: DateType =>
+              val dateRebaseFunc = DataSourceUtils.creteDateRebaseFuncInRead(
+                datetimeRebaseMode, "Parquet")
+              mutableRow.update(i, dateRebaseFunc(values(i).asInstanceOf[Integer]))
+            case d: DecimalType =>
+              val decimal = Decimal(values(i).asInstanceOf[Integer].toLong, d.precision, d.scale)
+              mutableRow.setDecimal(i, decimal, d.precision)
+            case _ => throw new IllegalArgumentException("Unexpected type for INT32")
+          }
+        }
+      case (PrimitiveType.PrimitiveTypeName.INT64, i) =>
+        if (values(i) == null) {
+          mutableRow.setNullAt(i)
+        } else {
+          dataSchema.fields(i).dataType match {
+            case long: LongType =>
+              mutableRow.setLong(i, values(i).asInstanceOf[Long])
+            case d: DecimalType =>
+              val decimal = Decimal(values(i).asInstanceOf[Long], d.precision, d.scale)
+              mutableRow.setDecimal(i, decimal, d.precision)
+            case _ => throw new IllegalArgumentException("Unexpected type for INT64")
+          }
+        }
+      case (PrimitiveType.PrimitiveTypeName.INT96, i) =>
+        if (values(i) == null) {
+          mutableRow.setNullAt(i)
+        } else {
+          dataSchema.fields(i).dataType match {
+            case l: LongType =>
+              mutableRow.setLong(i, values(i).asInstanceOf[Long])
+            case d: TimestampType =>
+              val int96RebaseFunc = DataSourceUtils.creteTimestampRebaseFuncInRead(
+                int96RebaseMode, "Parquet INT96")
+              val julianMicros =
+                ParquetRowConverter.binaryToSQLTimestamp(values(i).asInstanceOf[Binary])
+              val gregorianMicros = int96RebaseFunc(julianMicros)
+              val adjTime =
+                convertTz.map(DateTimeUtils.convertTz(gregorianMicros, _, ZoneOffset.UTC))
+                  .getOrElse(gregorianMicros)
+              mutableRow.setLong(i, adjTime)
+            case _ => throw new IllegalArgumentException("Unexpected type for INT96")
+          }
+        }
+      case (PrimitiveType.PrimitiveTypeName.FLOAT, i) =>
+        if (values(i) == null) {
+          mutableRow.setNullAt(i)
+        } else {
+          mutableRow.setFloat(i, values(i).asInstanceOf[Float])
+        }
+      case (PrimitiveType.PrimitiveTypeName.DOUBLE, i) =>
+        if (values(i) == null) {
+          mutableRow.setNullAt(i)
+        } else {
+          mutableRow.setDouble(i, values(i).asInstanceOf[Double])
+        }
+      case (PrimitiveType.PrimitiveTypeName.BOOLEAN, i) =>
+        if (values(i) == null) {
+          mutableRow.setNullAt(i)
+        } else {
+          mutableRow.setBoolean(i, values(i).asInstanceOf[Boolean])
+        }
+      case (PrimitiveType.PrimitiveTypeName.BINARY, i) =>
+        if (values(i) == null) {
+          mutableRow.setNullAt(i)
+        } else {
+          val bytes = values(i).asInstanceOf[Binary].getBytes
+          dataSchema.fields(i).dataType match {
+            case s: StringType =>
+              mutableRow.update(i, UTF8String.fromBytes(bytes))
+            case b: BinaryType =>
+              mutableRow.update(i, bytes)
+            case d: DecimalType =>
+              val decimal =
+                Decimal(new BigDecimal(new BigInteger(bytes), d.scale), d.precision, d.scale)
+              mutableRow.setDecimal(i, decimal, d.precision)
+            case _ => throw new IllegalArgumentException("Unexpected type for Binary")
+          }
+        }
+      case (PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, i) =>
+        if (values(i) == null) {
+          mutableRow.setNullAt(i)
+        } else {
+          val bytes = values(i).asInstanceOf[Binary].getBytes
+          dataSchema.fields(i).dataType match {
+            case d: DecimalType =>
+              val decimal =
+                Decimal(new BigDecimal(new BigInteger(bytes), d.scale), d.precision, d.scale)
+              mutableRow.setDecimal(i, decimal, d.precision)
+            case _ => throw new IllegalArgumentException("Unexpected type for FIXED_LEN_BYTE_ARRAY")
+          }
+        }
+      case _ =>
+        throw new IllegalArgumentException("Unexpected parquet type name")
+    }
+    mutableRow
+  }
+
+  /**
+   * When the Aggregates (Max/Min/Count) are pushed down to parquet, in the case of
+   * PARQUET_VECTORIZED_READER_ENABLED sets to true, we don't need buildColumnarReader
+   * to read data from parquet and aggregate at spark layer. Instead we want
+   * to calculate the Aggregates (Max/Min/Count) result using the statistics information
+   * from parquet footer file, and then construct a ColumnarBatch from these Aggregate results.
+   *
+   * @return Aggregate results in the format of ColumnarBatch
+   */
+  private[sql] def aggResultToSparkColumnarBatch(
+      footer: ParquetMetadata,
+      parquetTypes: Seq[PrimitiveType.PrimitiveTypeName],
+      values: Seq[Any],
+      dataSchema: StructType,
+      offHeap: Boolean,
+      datetimeRebaseModeInRead: String,
+      int96RebaseModeInRead: String,
+      convertTz: Option[ZoneId]): ColumnarBatch = {
+    val capacity = 4 * 1024
+    val footerFileMetaData = footer.getFileMetaData
+    val datetimeRebaseMode = DataSourceUtils.datetimeRebaseMode(
+      footerFileMetaData.getKeyValueMetaData.get,
+      datetimeRebaseModeInRead)
+    val int96RebaseMode = DataSourceUtils.int96RebaseMode(
+      footerFileMetaData.getKeyValueMetaData.get,
+      int96RebaseModeInRead)
+    val columnVectors = if (offHeap) {
+      OffHeapColumnVector.allocateColumns(capacity, dataSchema)
+    } else {
+      OnHeapColumnVector.allocateColumns(capacity, dataSchema)
+    }
+
+    parquetTypes.zipWithIndex.map {
+      case (PrimitiveType.PrimitiveTypeName.INT32, i) =>
+        if (values(i) == null) {
+          columnVectors(i).appendNull()
+        } else {
+          dataSchema.fields(i).dataType match {
+            case b: ByteType =>
+              columnVectors(i).appendByte(values(i).asInstanceOf[Integer].toByte)
+            case s: ShortType =>
+              columnVectors(i).appendShort(values(i).asInstanceOf[Integer].toShort)
+            case int: IntegerType =>
+              columnVectors(i).appendInt(values(i).asInstanceOf[Integer])
+            case d: DateType =>
+              val dateRebaseFunc = DataSourceUtils.creteDateRebaseFuncInRead(
+                datetimeRebaseMode, "Parquet")
+              columnVectors(i).appendInt(dateRebaseFunc(values(i).asInstanceOf[Integer]))
+            case _ => throw new IllegalArgumentException("Unexpected type for INT32")
+          }
+        }
+      case (PrimitiveType.PrimitiveTypeName.INT64, i) =>
+        if (values(i) == null) {
+          columnVectors(i).appendNull()
+        } else {
+          columnVectors(i).appendLong(values(i).asInstanceOf[Long])
+        }
+      case (PrimitiveType.PrimitiveTypeName.INT96, i) =>
+        if (values(i) == null) {
+          columnVectors(i).appendNull()
+        } else {
+          dataSchema.fields(i).dataType match {
+            case l: LongType =>
+              columnVectors(i).appendLong(values(i).asInstanceOf[Long])
+            case d: TimestampType =>
+              val int96RebaseFunc = DataSourceUtils.creteTimestampRebaseFuncInRead(
+                int96RebaseMode, "Parquet INT96")
+              val julianMicros =
+                ParquetRowConverter.binaryToSQLTimestamp(values(i).asInstanceOf[Binary])
+              val gregorianMicros = int96RebaseFunc(julianMicros)
+              val adjTime =
+                convertTz.map(DateTimeUtils.convertTz(gregorianMicros, _, ZoneOffset.UTC))
+                  .getOrElse(gregorianMicros)
+              columnVectors(i).appendLong(adjTime)
+            case _ => throw new IllegalArgumentException("Unexpected type for INT96")
+          }
+        }
+      case (PrimitiveType.PrimitiveTypeName.FLOAT, i) =>
+        if (values(i) == null) {
+          columnVectors(i).appendNull()
+        } else {
+          columnVectors(i).appendFloat(values(i).asInstanceOf[Float])
+        }
+      case (PrimitiveType.PrimitiveTypeName.DOUBLE, i) =>
+        if (values(i) == null) {
+          columnVectors(i).appendNull()
+        } else {
+          columnVectors(i).appendDouble(values(i).asInstanceOf[Double])
+        }
+      case (PrimitiveType.PrimitiveTypeName.BINARY, i) =>
+        if (values(i) == null) {
+          columnVectors(i).appendNull()
+        } else {
+          val bytes = values(i).asInstanceOf[Binary].getBytes
+          columnVectors(i).putByteArray(0, bytes, 0, bytes.length)
+        }
+      case (PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, i) =>
+        if (values(i) == null) {
+          columnVectors(i).appendNull()
+        } else {
+          val bytes = values(i).asInstanceOf[Binary].getBytes
+          columnVectors(i).putByteArray(0, bytes, 0, bytes.length)
+        }
+      case (PrimitiveType.PrimitiveTypeName.BOOLEAN, i) =>
+        if (values(i) == null) {
+          columnVectors(i).appendNull()
+        } else {
+          columnVectors(i).appendBoolean(values(i).asInstanceOf[Boolean])
+        }
+      case _ =>
+        throw new IllegalArgumentException("Unexpected parquet type name")
+    }
+    new ColumnarBatch(columnVectors.asInstanceOf[Array[ColumnVector]], 1)
+  }
+
+  /**
+   * Calculate the pushed down Aggregates (Max/Min/Count) result using the statistics
+   * information from parquet footer file.
+   *
+   * @return A tuple of `Array[PrimitiveType.PrimitiveTypeName]` and Array[Any].
+   *         The first element is the PrimitiveTypeName of the Aggregate column,
+   *         and the second element is the aggregated value.
+   */
+  private[sql] def getPushedDownAggResult(
+      footer: ParquetMetadata,
+      dataSchema: StructType,
+      aggregation: Aggregation)
+    : (Array[PrimitiveType.PrimitiveTypeName], Array[Any]) = {
+    val footerFileMetaData = footer.getFileMetaData
+    val fields = footerFileMetaData.getSchema.getFields
+    val blocks = footer.getBlocks()
+    val typesBuilder = ArrayBuilder.make[PrimitiveType.PrimitiveTypeName]
+    val valuesBuilder = ArrayBuilder.make[Any]
+
+    for (i <- 0 until aggregation.aggregateExpressions.size) {
+      var value: Any = None
+      var rowCount = 0L
+      var isCount = false
+      var index = 0
+      blocks.forEach { block =>
+        val blockMetaData = block.getColumns()
+        aggregation.aggregateExpressions(i) match {
+          case Max(col, _) =>
+            index = dataSchema.fieldNames.toList.indexOf(col)
+            val currentMax = getCurrentBlockMaxOrMin(footer, blockMetaData, index, true)
+            if (currentMax != None &&
+              (value == None || currentMax.asInstanceOf[Comparable[Any]].compareTo(value) > 0)) {
+              value = currentMax
+            }
+
+          case Min(col, _) =>
+            index = dataSchema.fieldNames.toList.indexOf(col)
+            val currentMin = getCurrentBlockMaxOrMin(footer, blockMetaData, index, false)
+            if (currentMin != None &&
+              (value == None || currentMin.asInstanceOf[Comparable[Any]].compareTo(value) < 0)) {
+              value = currentMin
+            }
+
+          case Count(col, _, _) =>
+            index = dataSchema.fieldNames.toList.indexOf(col)
+            rowCount += getRowCountFromParquetMetadata(footer)
+            if (!col.equals("1")) { // "1" is for count(*)
+              rowCount -= getNumNulls(footer, blockMetaData, index)
+            }
+            isCount = true
+
+          case _ =>
+        }
+      }
+      if (isCount) {
+        valuesBuilder += rowCount
+        typesBuilder += PrimitiveType.PrimitiveTypeName.INT96
+      } else {
+        valuesBuilder += value
+        typesBuilder += fields.get(index).asPrimitiveType.getPrimitiveTypeName
+      }
+    }
+    (typesBuilder.result(), valuesBuilder.result())
+  }
+
+  /**
+   * get the Max or Min value for ith column in the current block
+   *
+   * @return the Max or Min value
+   */
+  private def getCurrentBlockMaxOrMin(
+      footer: ParquetMetadata,
+      columnChunkMetaData: util.List[ColumnChunkMetaData],
+      i: Int,
+      isMax: Boolean): Any = {
+    val parquetType = footer.getFileMetaData.getSchema.getType(i)
+    if (!parquetType.isPrimitive) {
+      throw new IllegalArgumentException("Unsupported type : " + parquetType.toString)
+    }
+    val statistics = columnChunkMetaData.get(i).getStatistics()
+    if (isMax) statistics.genericGetMax() else statistics.genericGetMin()

Review comment:
   Good question. I actually need to check whether Parquet returns the min/max statistics. If not, I will either throw an exception or fall back to the non-pushdown path. I think falling back is the better solution.
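   For discussion, one possible shape of that fallback check is sketched below. It is not part of this PR: the object and method names (`StatsCheckSketch`, `hasUsableMinMaxStats`) are hypothetical, and it assumes parquet-mr's `Statistics#isEmpty` / `Statistics#hasNonNullValue`; the caller would take the regular (non-pushdown) read path whenever it returns false.

    import scala.collection.JavaConverters._

    import org.apache.parquet.hadoop.metadata.ParquetMetadata

    object StatsCheckSketch {
      // Returns true only when every row group carries usable min/max statistics
      // for the column at columnIndex. If any row group lacks them, the aggregate
      // cannot be answered from the footer and the scan should fall back.
      def hasUsableMinMaxStats(footer: ParquetMetadata, columnIndex: Int): Boolean = {
        footer.getBlocks.asScala.forall { block =>
          val stats = block.getColumns.get(columnIndex).getStatistics
          // Statistics can be null or empty for files written without column statistics.
          stats != null && !stats.isEmpty && stats.hasNonNullValue
        }
      }
    }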