danielxjd commented on a change in pull request #12223:
URL: https://github.com/apache/beam/pull/12223#discussion_r458297447
########## File path: sdks/java/io/parquet/src/main/java/org/apache/beam/sdk/io/parquet/ParquetIO.java ##########
@@ -235,12 +283,164 @@ public ReadFiles withAvroDataModel(GenericData model) {
      return toBuilder().setAvroDataModel(model).build();
    }

    public ReadFiles withSplit() {
      return toBuilder().setSplit(true).build();
    }

    @Override
    public PCollection<GenericRecord> expand(PCollection<FileIO.ReadableFile> input) {
      checkNotNull(getSchema(), "Schema can not be null");
      if (getSplit()) {
        return input
            .apply(ParDo.of(new SplitReadFn(getAvroDataModel())))
            .setCoder(AvroCoder.of(getSchema()));
      } else {
        return input
            .apply(ParDo.of(new ReadFn(getAvroDataModel())))
            .setCoder(AvroCoder.of(getSchema()));
      }
    }

    @DoFn.BoundedPerElement
    static class SplitReadFn extends DoFn<FileIO.ReadableFile, GenericRecord> {
      private Class<? extends GenericData> modelClass;
      private static final Logger LOG = LoggerFactory.getLogger(SplitReadFn.class);
      ReadSupport<GenericRecord> readSupport;

      SplitReadFn(GenericData model) {
        this.modelClass = model != null ? model.getClass() : null;
      }

      private static <K, V> Map<K, Set<V>> toSetMultiMap(Map<K, V> map) {
        Map<K, Set<V>> setMultiMap = new HashMap<K, Set<V>>();
        for (Map.Entry<K, V> entry : map.entrySet()) {
          Set<V> set = new HashSet<V>();
          set.add(entry.getValue());
          setMultiMap.put(entry.getKey(), Collections.unmodifiableSet(set));
        }
        return Collections.unmodifiableMap(setMultiMap);
      }

      private InputFile getInputFile(FileIO.ReadableFile file) throws IOException {
        if (!file.getMetadata().isReadSeekEfficient()) {
          throw new RuntimeException(
              String.format("File has to be seekable: %s", file.getMetadata().resourceId()));
        }
        return new BeamParquetInputFile(file.openSeekable());
      }

      @ProcessElement
      public void processElement(
          @Element FileIO.ReadableFile file,
          RestrictionTracker<OffsetRange, Long> tracker,
          OutputReceiver<GenericRecord> outputReceiver)
          throws Exception {
        ReadSupport<GenericRecord> readSupport;
        InputFile inputFile = getInputFile(file);
        Configuration conf = setConf();
        GenericData model = null;
        if (modelClass != null) {
          model = (GenericData) modelClass.getMethod("get").invoke(null);
        }
        readSupport = new AvroReadSupport<GenericRecord>(model);
        ParquetReadOptions options = HadoopReadOptions.builder(conf).build();
        ParquetFileReader reader = ParquetFileReader.open(inputFile, options);
        Filter filter = checkNotNull(options.getRecordFilter(), "filter");
        conf = ((HadoopReadOptions) options).getConf();
        for (String property : options.getPropertyNames()) {
          conf.set(property, options.getProperty(property));
        }
        FileMetaData parquetFileMetadata = reader.getFooter().getFileMetaData();
        MessageType fileSchema = parquetFileMetadata.getSchema();
        Map<String, String> fileMetadata = parquetFileMetadata.getKeyValueMetaData();

        ReadSupport.ReadContext readContext =
            readSupport.init(new InitContext(conf, toSetMultiMap(fileMetadata), fileSchema));
        ColumnIOFactory columnIOFactory = new ColumnIOFactory(parquetFileMetadata.getCreatedBy());
        MessageType requestedSchema = readContext.getRequestedSchema();
        RecordMaterializer<GenericRecord> recordConverter =
            readSupport.prepareForRead(conf, fileMetadata, fileSchema, readContext);
        boolean strictTypeChecking = options.isEnabled(STRICT_TYPE_CHECKING, true);
        boolean filterRecords = options.useRecordFilter();
        reader.setRequestedSchema(requestedSchema);
        MessageColumnIO columnIO =
            columnIOFactory.getColumnIO(requestedSchema, fileSchema, strictTypeChecking);
        long currentBlock = tracker.currentRestriction().getFrom();
        for (int i = 0; i < currentBlock; i++) {
          reader.skipNextRowGroup();
        }
        while (tracker.tryClaim(currentBlock)) {
          PageReadStore pages = reader.readNextRowGroup();
          currentBlock += 1;
          RecordReader<GenericRecord> recordReader =
              columnIO.getRecordReader(
                  pages, recordConverter, filterRecords ? filter : FilterCompat.NOOP);
          GenericRecord read;
          long currentRow = 0;
          long totalRows = pages.getRowCount();
          while (currentRow < totalRows) {
            outputReceiver.output(recordReader.read());
            currentRow += 1;
          }
        }
      }

      private Configuration setConf() throws Exception {
        Configuration conf = new Configuration();
        GenericData model = null;
        if (modelClass != null) {
          model = (GenericData) modelClass.getMethod("get").invoke(null);
        }
        if (model != null
            && (model.getClass() == GenericData.class || model.getClass() == SpecificData.class)) {
          conf.setBoolean(AvroReadSupport.AVRO_COMPATIBILITY, true);
        } else {
          conf.setBoolean(AvroReadSupport.AVRO_COMPATIBILITY, false);
        }
        return conf;
      }

      @GetInitialRestriction
      public OffsetRange getInitialRestriction(@Element FileIO.ReadableFile file) throws Exception {
        InputFile inputFile = getInputFile(file);
        Configuration conf = setConf();
        ParquetReadOptions options = HadoopReadOptions.builder(conf).build();
        ParquetFileReader reader = ParquetFileReader.open(inputFile, options);
        return new OffsetRange(0, reader.getRowGroups().size());
      }

      @SplitRestriction
      public void split(@Restriction OffsetRange restriction, OutputReceiver<OffsetRange> out) {
        for (OffsetRange range : restriction.split(1, 0)) {
          out.output(range);
        }
      }

      @NewTracker
      public OffsetRangeTracker newTracker(@Restriction OffsetRange restriction) {
        return new OffsetRangeTracker(restriction);
      }

      @GetSize
      public double getSize(@Element FileIO.ReadableFile file, @Restriction OffsetRange restriction)
          throws Exception {
        InputFile inputFile = getInputFile(file);
        Configuration conf = setConf();
        ParquetReadOptions options = HadoopReadOptions.builder(conf).build();
        ParquetFileReader reader = ParquetFileReader.open(inputFile, options);
        if (restriction == null) {
          return reader.getRecordCount();
        } else {
          long start = restriction.getFrom();
          long end = restriction.getTo();
          List<BlockMetaData> blocks = reader.getRowGroups();
          double size = 0;
          for (long i = start; i < end; i++) {
            size += blocks.get((int) i).getRowCount();
          }
          return size;

Review comment:
   I think getSize is used to get an idea of the current progress and to make optimized split decisions. Since the size of each row is similar within a given file, the number of rows is a fairly accurate representation of the work; multiplying it by a constant average row size would not change the result much when restrictions are compared (see the sketch after this message). The previous approach to reporting progress also used the number of rows (records).

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org
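
Below is a minimal, self-contained sketch (not part of the PR; the class name, row-group counts, and average row width are made up for illustration) of the point in the review comment: if rows in a file have roughly uniform width, then sizing restrictions by row count and sizing them by row count times a constant average row width produce the same relative comparisons, so a runner's progress and split decisions based on getSize would not change.

    public class RowCountSizeSketch {

      // Size of a restriction [from, to) over row groups, measured purely in rows,
      // mirroring how the PR's getSize() sums getRowCount() per row group.
      static double sizeInRows(long[] rowGroupRowCounts, int from, int to) {
        double size = 0;
        for (int i = from; i < to; i++) {
          size += rowGroupRowCounts[i];
        }
        return size;
      }

      public static void main(String[] args) {
        long[] rowCounts = {1_000, 5_000, 2_500, 500}; // hypothetical row groups of one file
        double avgRowBytes = 120.0;                    // hypothetical constant average row width

        double a = sizeInRows(rowCounts, 0, 2); // restriction over row groups [0, 2)
        double b = sizeInRows(rowCounts, 2, 4); // restriction over row groups [2, 4)

        // Scaling both estimates by the same constant leaves their ratio unchanged,
        // so relative work estimates (and decisions derived from them) are identical.
        System.out.printf("rows:  a=%.0f b=%.0f  a/b=%.3f%n", a, b, a / b);
        System.out.printf(
            "bytes: a=%.0f b=%.0f  a/b=%.3f%n",
            a * avgRowBytes, b * avgRowBytes, (a * avgRowBytes) / (b * avgRowBytes));
      }
    }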