danielxjd commented on a change in pull request #12223:
URL: https://github.com/apache/beam/pull/12223#discussion_r464703969



##########
File path: 
sdks/java/io/parquet/src/main/java/org/apache/beam/sdk/io/parquet/ParquetIO.java
##########
@@ -230,12 +281,271 @@ public ReadFiles withAvroDataModel(GenericData model) {
       return toBuilder().setAvroDataModel(model).build();
     }
 
+    /** Returns the transform built from {@code toBuilder()} with the split flag set to true. */
+    public ReadFiles withSplit() {
+      return toBuilder().setSplit(true).build();
+    }
+
     @Override
     public PCollection<GenericRecord> expand(PCollection<FileIO.ReadableFile> input) {
       checkNotNull(getSchema(), "Schema can not be null");
-      return input
-          .apply(ParDo.of(new ReadFn(getAvroDataModel())))
-          .setCoder(AvroCoder.of(getSchema()));
+      // withSplit() sets the flag to true, so the splittable DoFn must be the opt-in path and
+      // the pre-existing ReadFn must stay the default. The previous condition was inverted.
+      if (getSplit()) {
+        return input
+            .apply(ParDo.of(new SplitReadFn(getAvroDataModel())))
+            .setCoder(AvroCoder.of(getSchema()));
+      }
+      return input
+          .apply(ParDo.of(new ReadFn(getAvroDataModel())))
+          .setCoder(AvroCoder.of(getSchema()));
+    }
+
+    @DoFn.BoundedPerElement
+    static class SplitReadFn extends DoFn<FileIO.ReadableFile, GenericRecord> {
+      private Class<? extends GenericData> modelClass;
+      private static final Logger LOG = 
LoggerFactory.getLogger(SplitReadFn.class);
+      private static final long SPLIT_LIMIT = 64000000;
+      ReadSupport<GenericRecord> readSupport;
+
+      /**
+       * Keeps only the class of {@code model} ({@code null} when no model is given); the model
+       * instance is looked up again later through that class's static {@code get()} method.
+       */
+      SplitReadFn(GenericData model) {
+        this.modelClass = model != null ? model.getClass() : null;
+      }
+
+      /**
+       * Converts a plain map into a read-only map whose values are read-only sets holding the
+       * single original value of each key.
+       */
+      private static <K, V> Map<K, Set<V>> toSetMultiMap(Map<K, V> map) {
+        Map<K, Set<V>> wrapped = new HashMap<>();
+        for (K key : map.keySet()) {
+          Set<V> singleValue = new HashSet<>();
+          singleValue.add(map.get(key));
+          wrapped.put(key, Collections.unmodifiableSet(singleValue));
+        }
+        return Collections.unmodifiableMap(wrapped);
+      }
+
+      /**
+       * Opens {@code file} as a Parquet {@link InputFile} backed by its seekable channel.
+       *
+       * @throws RuntimeException if the file's metadata reports that seeking is not efficient
+       */
+      private InputFile getInputFile(FileIO.ReadableFile file) throws IOException {
+        if (file.getMetadata().isReadSeekEfficient()) {
+          return new BeamParquetInputFile(file.openSeekable());
+        }
+        String message =
+            String.format("File has to be seekable: %s", file.getMetadata().resourceId());
+        throw new RuntimeException(message);
+      }
+
+      @ProcessElement
+      public void processElement(
+          @Element FileIO.ReadableFile file,
+          RestrictionTracker<OffsetRange, Long> tracker,
+          OutputReceiver<GenericRecord> outputReceiver)
+          throws Exception {
+        ReadSupport<GenericRecord> readSupport;
+        InputFile inputFile = getInputFile(file);
+        Configuration conf = setConf();
+        GenericData model = null;
+        if (modelClass != null) {
+          model = (GenericData) modelClass.getMethod("get").invoke(null);
+        }
+        readSupport = new AvroReadSupport<GenericRecord>(model);
+        ParquetReadOptions options = HadoopReadOptions.builder(conf).build();
+        ParquetFileReader reader = ParquetFileReader.open(inputFile, options);
+        Filter filter = checkNotNull(options.getRecordFilter(), "filter");
+        conf = ((HadoopReadOptions) options).getConf();
+        for (String property : options.getPropertyNames()) {
+          conf.set(property, options.getProperty(property));
+        }
+        FileMetaData parquetFileMetadata = 
reader.getFooter().getFileMetaData();
+        MessageType fileSchema = parquetFileMetadata.getSchema();
+        Map<String, String> fileMetadata = 
parquetFileMetadata.getKeyValueMetaData();
+
+        ReadSupport.ReadContext readContext =
+            readSupport.init(new InitContext(conf, 
toSetMultiMap(fileMetadata), fileSchema));
+        ColumnIOFactory columnIOFactory = new 
ColumnIOFactory(parquetFileMetadata.getCreatedBy());
+        MessageType requestedSchema = readContext.getRequestedSchema();
+        RecordMaterializer<GenericRecord> recordConverter =
+            readSupport.prepareForRead(conf, fileMetadata, fileSchema, 
readContext);
+        boolean strictTypeChecking = options.isEnabled(STRICT_TYPE_CHECKING, 
true);
+        boolean filterRecords = options.useRecordFilter();
+        reader.setRequestedSchema(requestedSchema);
+        MessageColumnIO columnIO =
+            columnIOFactory.getColumnIO(requestedSchema, fileSchema, 
strictTypeChecking);
+        long currentBlock = tracker.currentRestriction().getFrom();
+        for (int i = 0; i < currentBlock; i++) {
+          reader.skipNextRowGroup();
+        }
+        while (tracker.tryClaim(currentBlock)) {
+
+          LOG.info("reading block" + currentBlock);
+          PageReadStore pages = reader.readNextRowGroup();
+          currentBlock += 1;
+          RecordReader<GenericRecord> recordReader =
+              columnIO.getRecordReader(
+                  pages, recordConverter, filterRecords ? filter : 
FilterCompat.NOOP);
+          long currentRow = 0;
+          long totalRows = pages.getRowCount();
+          while (currentRow < totalRows) {
+            try {
+              GenericRecord record;
+              currentRow += 1;
+              try {
+                record = recordReader.read();
+              } catch (RecordMaterializer.RecordMaterializationException e) {
+                LOG.debug("skipping a corrupt record");
+                continue;
+              }
+              if (recordReader.shouldSkipCurrentRecord()) {
+                // this record is being filtered via the filter2 package
+                LOG.debug("skipping record");
+                continue;
+              }
+              if (record == null) {
+                // only happens with FilteredRecordReader at end of block
+                LOG.debug("filtered record reader reached end of block");
+                break;
+              }
+              if (tracker instanceof BlockTracker) {
+                ((BlockTracker) tracker).makeProgress();
+              }
+              outputReceiver.output(record);
+            } catch (RuntimeException e) {
+              throw new ParquetDecodingException(
+                  format(
+                      "Can not read value at %d in block %d in file %s",
+                      currentRow, currentBlock, file.toString()),
+                  e);
+            }
+          }
+          LOG.info("finish read " + currentRow + " rows");
+        }
+      }
+
+      /**
+       * Builds a fresh Hadoop {@link Configuration} in which Avro compatibility is enabled only
+       * when the configured data model resolves to exactly {@code GenericData} or
+       * {@code SpecificData}.
+       */
+      private Configuration setConf() throws Exception {
+        Configuration conf = new Configuration();
+        GenericData model =
+            modelClass == null ? null : (GenericData) modelClass.getMethod("get").invoke(null);
+        Class<?> resolvedClass = model == null ? null : model.getClass();
+        boolean avroCompatible =
+            resolvedClass == GenericData.class || resolvedClass == SpecificData.class;
+        conf.setBoolean(AvroReadSupport.AVRO_COMPATIBILITY, avroCompatible);
+        return conf;
+      }
+
+      /**
+       * Returns the initial restriction for {@code file}: the range {@code [0, n)} where
+       * {@code n} is the number of row groups the reader reports for the file.
+       */
+      @GetInitialRestriction
+      public OffsetRange getInitialRestriction(@Element FileIO.ReadableFile file) throws Exception {
+        InputFile inputFile = getInputFile(file);
+        ParquetReadOptions options = HadoopReadOptions.builder(setConf()).build();
+        // Only the footer metadata is needed here; release the reader as soon as it is read.
+        try (ParquetFileReader reader = ParquetFileReader.open(inputFile, options)) {
+          return new OffsetRange(0, reader.getRowGroups().size());
+        }
+      }
+
+      /**
+       * Splits {@code restriction} (a range of row-group indices) into consecutive sub-ranges,
+       * closing a sub-range as soon as the accumulated row-group byte size exceeds
+       * {@code SPLIT_LIMIT}. The emitted ranges are contiguous and together cover the whole
+       * restriction.
+       *
+       * @param restriction range of row-group indices to split
+       * @param out receiver for the resulting sub-ranges
+       * @param file the Parquet file whose row-group sizes drive the split
+       */
+      @SplitRestriction
+      public void split(
+          @Restriction OffsetRange restriction,
+          OutputReceiver<OffsetRange> out,
+          @Element FileIO.ReadableFile file)
+          throws Exception {
+        InputFile inputFile = getInputFile(file);
+        ParquetReadOptions options = HadoopReadOptions.builder(setConf()).build();
+        try (ParquetFileReader reader = ParquetFileReader.open(inputFile, options)) {
+          List<BlockMetaData> rowGroups = reader.getRowGroups();
+          long start = restriction.getFrom();
+          long end = start;
+          long totalSize = 0;
+          for (long i = restriction.getFrom(); i < restriction.getTo(); i++) {
+            totalSize += rowGroups.get((int) i).getTotalByteSize();
+            end = i + 1;
+            if (totalSize > SPLIT_LIMIT) {
+              // Emit the accumulated range BEFORE moving start; doing it afterwards would
+              // output an empty [end, end) range and lose these row groups.
+              out.output(new OffsetRange(start, end));
+              start = end;
+              totalSize = 0;
+            }
+          }
+          // Emit whatever is left, including trailing row groups that report zero bytes.
+          if (start < end) {
+            out.output(new OffsetRange(start, end));
+          }
+        }
+      }
+
+      /**
+       * Creates the tracker for {@code restriction}, passing it the byte size and the record
+       * count computed for that restriction of {@code file}.
+       */
+      @NewTracker
+      public RestrictionTracker<OffsetRange, Long> newTracker(
+          @Restriction OffsetRange restriction, @Element FileIO.ReadableFile file)
+          throws Exception {
+        long totalBytes = (long) getSize(file, restriction);
+        long totalRecords = getRecordCount(file, restriction);
+        return new BlockTracker(restriction, totalBytes, totalRecords);
+      }
+
+      /** Returns a new {@code OffsetRange.Coder} as the coder for this DoFn's restrictions. */
+      @GetRestrictionCoder
+      public OffsetRange.Coder getRestrictionCoder() {
+        return new OffsetRange.Coder();
+      }
+
+      /**
+       * Sums the row counts of the row groups whose indices fall inside {@code restriction}.
+       *
+       * @param file the Parquet file to inspect
+       * @param restriction range of row-group indices; {@code null} yields 0
+       * @return total number of rows in the selected row groups
+       */
+      public long getRecordCount(
+          @Element FileIO.ReadableFile file, @Restriction OffsetRange restriction)
+          throws Exception {
+        if (restriction == null) {
+          // Nothing to count, so do not open the file at all.
+          return 0;
+        }
+        InputFile inputFile = getInputFile(file);
+        ParquetReadOptions options = HadoopReadOptions.builder(setConf()).build();
+        try (ParquetFileReader reader = ParquetFileReader.open(inputFile, options)) {
+          // Fetch the row-group list once instead of on every iteration.
+          List<BlockMetaData> rowGroups = reader.getRowGroups();
+          long recordCount = 0;
+          for (long i = restriction.getFrom(); i < restriction.getTo(); i++) {
+            recordCount += rowGroups.get((int) i).getRowCount();
+          }
+          return recordCount;
+        }
+      }
+
+      public double getSize(@Element FileIO.ReadableFile file, @Restriction 
OffsetRange restriction)
+          throws Exception {
+        InputFile inputFile = getInputFile(file);
+        Configuration conf = setConf();
+        ParquetReadOptions options = HadoopReadOptions.builder(conf).build();
+        ParquetFileReader reader = ParquetFileReader.open(inputFile, options);
+        if (restriction == null) {
+          return 0;
+        } else {
+          double size = 0;
+          for (long i = restriction.getFrom(); i < restriction.getTo(); i++) {
+            size += reader.getRowGroups().get((int) i).getTotalByteSize();

Review comment:
       It should be an O(N) operation, where N is the number of blocks.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to