Github user fhueske commented on a diff in the pull request: https://github.com/apache/flink/pull/1517#discussion_r52459974 --- Diff: flink-runtime/src/main/java/org/apache/flink/runtime/operators/hash/ReduceHashTable.java --- @@ -0,0 +1,1014 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.operators.hash; + +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.typeutils.SameTypePairComparator; +import org.apache.flink.api.common.typeutils.TypeComparator; +import org.apache.flink.api.common.typeutils.TypePairComparator; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.MemorySegment; +import org.apache.flink.runtime.io.disk.RandomAccessInputView; +import org.apache.flink.runtime.memory.AbstractPagedOutputView; +import org.apache.flink.runtime.util.MathUtils; +import org.apache.flink.util.Collector; +import org.apache.flink.util.MutableObjectIterator; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.EOFException; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * This hash table supports updating elements, and it also has processRecordWithReduce, + * which performs one reduce step with the given record. + * + * The memory is divided into three areas: + * - Bucket area: contains the bucket heads: + * an 8 byte pointer to the first link of a linked list in the record area + * - Record area: this contains the actual data in linked list elements. A linked list element starts + * with an 8 byte pointer to the next element, and then the record follows. + * - Staging area: This is a small, temporary storage area for writing updated records. This is needed, + * because before serializing a record, there is no way to know in advance how large it will be. + * Therefore, we can't serialize directly into the record area when we are doing an update, because + * if it turns out to be larger than the old record, then it would overwrite some other record + * that happens to be after the old one in memory. The solution is to serialize to the staging area first, + * and then copy it to the place of the original if it has the same size, otherwise allocate a new linked + * list element at the end of the record area, and mark the old one as abandoned. This creates "holes" in + * the record area, so compactions are eventually needed. + * + * Compaction happens by deleting everything in the bucket area, and then reinserting all elements.
+ * The reinsertion happens by forgetting the structure (the linked lists) of the record area, reading it + * sequentially, and inserting all non-abandoned records, starting from the beginning of the record area. + * Note that insertions never overwrite a record that has not been read by the reinsertion sweep, because + * both the insertions and readings happen sequentially in the record area, and the insertions obviously + * never overtake the reading sweep. + * + * Note: we have to abandon the old linked list element even when the updated record has a smaller size + * than the original, because otherwise we wouldn't know where the next record starts during a reinsertion + * sweep. + * + * The number of buckets depends on how large the records are. The serializer might be able to tell us this; + * in that case, we will calculate the number of buckets upfront, and won't do resizes. + * If the serializer doesn't know the size, then we start with a small number of buckets, and resize when + * more elements have been inserted than there are buckets. + * + * The number of memory segments given to the staging area is usually one, because it just needs to hold + * one record. + * + * Note: For hashing, we need to use MathUtils.hash because of its avalanche property, so that + * changing only some high bits of the original value doesn't leave the lower bits of the hash unaffected. + * This is because when choosing the bucket for a record, we mask only the + * lower bits (see numBucketsMask). Lots of collisions would occur when, for example, + * the hashed values are bitsets that differ only in their higher bits. + */ + +public class ReduceHashTable<T> extends AbstractMutableHashTable<T> { + + private static final Logger LOG = LoggerFactory.getLogger(ReduceHashTable.class); + + /** The minimum number of memory segments ReduceHashTable needs to be supplied with in order to work. */ + private static final int MIN_NUM_MEMORY_SEGMENTS = 3; + + /** The last link in the linked lists will have this as next pointer. */ + private static final long END_OF_LIST = -1; + + /** + * The next pointer of a link will have this value if it is not part of the linked list. + * (This can happen because the record couldn't be updated in-place due to a size change.) + * Note: the record that is in the link should still be readable, so that it is possible to determine + * the size of the place (see EntryIterator). + * Note: the last record in the record area can't be abandoned. (EntryIterator makes use of this fact.) + */ + private static final long ABANDONED_RECORD = -2; + + /** This value means that prevElemPtr is "pointing to the bucket head", and not into the record segments. */ + private static final long INVALID_PREV_POINTER = -3; + + private static final long RECORD_OFFSET_IN_LINK = 8; + + + /** this is used by processRecordWithReduce */ + private final ReduceFunction<T> reducer; + + /** emit() sends data to outputCollector */ + private final Collector<T> outputCollector; + + private final boolean objectReuseEnabled; + + /** + * This initially contains all the memory we have, and then segments + * are taken from it by bucketSegments, recordArea, and stagingSegments. + */ + private final ArrayList<MemorySegment> freeMemorySegments; + + private final int numAllMemorySegments; + + private final int segmentSize; + + /** + * These will contain the bucket heads.
+ * The bucket heads are pointers to the linked lists containing the actual records. + */ + private MemorySegment[] bucketSegments; + + private static final int bucketSize = 8, bucketSizeBits = 3; + + private int numBuckets; + private int numBucketsMask; + private final int numBucketsPerSegment, numBucketsPerSegmentBits, numBucketsPerSegmentMask; + + /** + * The segments where the actual data is stored. + */ + private final RecordArea recordArea; + + /** + * Segments for the staging area. + * (It should contain at most one record at all times.) + */ + private final ArrayList<MemorySegment> stagingSegments; + private final RandomAccessInputView stagingSegmentsInView; + private final StagingOutputView stagingSegmentsOutView; + + private T reuse; + + /** This is the internal prober that insertOrReplaceRecord and processRecordWithReduce use. */ + private final HashTableProber<T> prober; + + /** The number of elements currently held by the table. */ + private long numElements = 0; + + /** The number of bytes wasted by updates that couldn't overwrite the old record. */ + private long holes = 0; + + /** + * If the serializer knows the size of the records, then we can calculate the optimal number of buckets + * upfront, so we don't need resizes. + */ + private boolean enableResize; + + + /** + * This constructor is for the case when the caller will only call those operations that are also + * present on CompactingHashTable. + */ + public ReduceHashTable(TypeSerializer<T> serializer, TypeComparator<T> comparator, List<MemorySegment> memory) { + this(serializer, comparator, memory, null, null, false); + } + + public ReduceHashTable(TypeSerializer<T> serializer, TypeComparator<T> comparator, List<MemorySegment> memory, + ReduceFunction<T> reducer, Collector<T> outputCollector, boolean objectReuseEnabled) { + super(serializer, comparator); + this.reducer = reducer; + this.numAllMemorySegments = memory.size(); + this.freeMemorySegments = new ArrayList<>(memory); + this.outputCollector = outputCollector; + this.objectReuseEnabled = objectReuseEnabled; + + // some sanity checks first + if (freeMemorySegments.size() < MIN_NUM_MEMORY_SEGMENTS) { + throw new IllegalArgumentException("Too few memory segments provided. ReduceHashTable needs at least " + + MIN_NUM_MEMORY_SEGMENTS + " memory segments."); + } + + // Get the size of the first memory segment and record it. All further buffers must have the same size.
+ // the size must also be a power of 2 + segmentSize = freeMemorySegments.get(0).size(); + if ( (segmentSize & segmentSize - 1) != 0) { + throw new IllegalArgumentException("Hash Table requires buffers whose size is a power of 2."); + } + + this.numBucketsPerSegment = segmentSize / bucketSize; + this.numBucketsPerSegmentBits = MathUtils.log2strict(this.numBucketsPerSegment); + this.numBucketsPerSegmentMask = (1 << this.numBucketsPerSegmentBits) - 1; + + recordArea = new RecordArea(segmentSize); + + stagingSegments = new ArrayList<>(); + stagingSegmentsInView = new RandomAccessInputView(stagingSegments, segmentSize); + stagingSegmentsOutView = new StagingOutputView(stagingSegments, segmentSize); + + prober = new HashTableProber<>(buildSideComparator, new SameTypePairComparator<>(buildSideComparator)); + + enableResize = buildSideSerializer.getLength() == -1; + } + + private void open(int numBucketSegments) { + synchronized (stateLock) { + if (!closed) { + throw new IllegalStateException("currently not closed."); + } + closed = false; + } + + allocateBucketSegments(numBucketSegments); + + stagingSegments.add(allocateSegment()); + + reuse = buildSideSerializer.createInstance(); + } + + /** + * Initialize the hash table + */ + @Override + public void open() { + open(calcInitialNumBucketSegments()); + } + + @Override + public void close() { + // make sure that we close only once + synchronized (stateLock) { + if (closed) { + return; + } + closed = true; + } + + LOG.debug("Closing ReduceHashTable and releasing resources."); + + releaseBucketSegments(); + + recordArea.giveBackSegments(); + + freeMemorySegments.addAll(stagingSegments); + stagingSegments.clear(); + + numElements = 0; + holes = 0; + } + + @Override + public void abort() { + // ReduceHashTable doesn't have closed loops like CompactingHashTable.buildTableWithUniqueKey. + } + + @Override + public List<MemorySegment> getFreeMemory() { + if (!this.closed) { + throw new IllegalStateException("Cannot return memory while ReduceHashTable is open."); + } + + return freeMemorySegments; + } + + private int calcInitialNumBucketSegments() { + int recordLength = buildSideSerializer.getLength(); + double fraction; + if (recordLength == -1) { + // It seems that resizing is quite efficient, so we can err here on the too few bucket segments side. + // Even with small records, we lose only ~15% speed. + fraction = 0.1; + } else { + fraction = 8.0 / (16 + recordLength); + // note: enableResize is false in this case, so no resizing will happen + } + + int ret = Math.max(1, MathUtils.roundDownToPowerOf2((int)(numAllMemorySegments * fraction))); + + // We can't handle more than Integer.MAX_VALUE buckets (eg. 
because hash functions return int) + if ((long)ret * numBucketsPerSegment > Integer.MAX_VALUE) { + ret = MathUtils.roundDownToPowerOf2(Integer.MAX_VALUE / numBucketsPerSegment); + } + return ret; + } + + private void allocateBucketSegments(int numBucketSegments) { + if (numBucketSegments < 1) { + throw new RuntimeException("Bug in ReduceHashTable"); + } + + bucketSegments = new MemorySegment[numBucketSegments]; + for(int i = 0; i < bucketSegments.length; i++) { + bucketSegments[i] = allocateSegment(); + if (bucketSegments[i] == null) { + throw new RuntimeException("Bug in ReduceHashTable: allocateBucketSegments should be " + + "called in a way that there is enough free memory."); + } + // Init all pointers in all buckets to END_OF_LIST + for(int j = 0; j < numBucketsPerSegment; j++) { + bucketSegments[i].putLong(j << bucketSizeBits, END_OF_LIST); + } + } + numBuckets = numBucketSegments * numBucketsPerSegment; + numBucketsMask = (1 << MathUtils.log2strict(numBuckets)) - 1; + } + + private void releaseBucketSegments() { + freeMemorySegments.addAll(Arrays.asList(bucketSegments)); + bucketSegments = null; + } + + private MemorySegment allocateSegment() { + int s = freeMemorySegments.size(); + if (s > 0) { + return freeMemorySegments.remove(s - 1); + } else { + return null; + } + } + + /** + * Searches the hash table for the record with matching key, and updates it (making one reduce step) if found, + * otherwise inserts a new entry. + * + * (If there are multiple entries with the same key, then it will update one of them.) + * + * @param record The record to be processed. + */ + public void processRecordWithReduce(T record) throws Exception { + if (closed) { + return; + } + + T match = prober.getMatchFor(record, reuse); + if (match == null) { + prober.insertAfterNoMatch(record); + } else { + // do the reduce step + T res = reducer.reduce(match, record); + + // We have given reuse to the reducer UDF, so create new one if object reuse is disabled + if (!objectReuseEnabled) { + reuse = buildSideSerializer.createInstance(); + } + + prober.updateMatch(res); + } + } + + /** + * Searches the hash table for a record with the given key. + * If it is found, then it is overridden with the specified record. + * Otherwise, the specified record is inserted. + * @param record The record to insert or to replace with. + * @throws IOException (EOFException specifically, if memory ran out) + */ + @Override + public void insertOrReplaceRecord(T record) throws IOException { + if (closed) { + return; + } + + T match = prober.getMatchFor(record, reuse); + if (match == null) { + prober.insertAfterNoMatch(record); + } else { + prober.updateMatch(record); + } + } + + /** + * Inserts the given record into the hash table. + * Note: this method doesn't care about whether a record with the same key is already present. + * @param record The record to insert. 
+ * @throws IOException (EOFException specifically, if memory ran out) + */ + @Override + public void insert(T record) throws IOException { + if (closed) { + return; + } + + final int hashCode = MathUtils.hash(buildSideComparator.hash(record)); + final int bucket = hashCode & numBucketsMask; + final int bucketSegmentIndex = bucket >>> numBucketsPerSegmentBits; // which segment contains the bucket + final MemorySegment bucketSegment = bucketSegments[bucketSegmentIndex]; + final int bucketOffset = (bucket & numBucketsPerSegmentMask) << bucketSizeBits; // offset of the bucket in the segment + final long firstPointer = bucketSegment.getLong(bucketOffset); + + try { + final long newFirstPointer = recordArea.appendPointerAndRecord(firstPointer, record); + bucketSegment.putLong(bucketOffset, newFirstPointer); + } catch (EOFException ex) { + compactOrThrow(); + insert(record); + return; + } + + numElements++; + resizeTableIfNecessary(); + } + + private void resizeTableIfNecessary() throws IOException { + if (enableResize && numElements > numBuckets) { + final long newNumBucketSegments = 2L * bucketSegments.length; + // Checks: + // - we can't handle more than Integer.MAX_VALUE buckets + // - don't take more memory than the free memory we have left + // - the buckets shouldn't occupy more than half of all our memory + if (newNumBucketSegments * numBucketsPerSegment < Integer.MAX_VALUE && + newNumBucketSegments - bucketSegments.length < freeMemorySegments.size() && + newNumBucketSegments < numAllMemorySegments / 2) { + // do the resize + rebuild(newNumBucketSegments); + } + } + } + + /** + * WARNING: Doing any other operation on the table invalidates the iterator! (Even + * using getMatchFor of a prober!) + */ + public final class EntryIterator implements MutableObjectIterator<T> { + + private final long endPosition; + + public EntryIterator() { + endPosition = recordArea.getAppendPosition(); + if (endPosition == 0) { + return; + } + recordArea.setReadPosition(0); + } + + @Override + public T next(T reuse) throws IOException { + if (endPosition != 0 && recordArea.getReadPosition() < endPosition) { + // Loop until we find a non-abandoned record. + // Note: the last record in the record area can't be abandoned. + while (true) { + final boolean isAbandoned = recordArea.readLong() == ABANDONED_RECORD; + reuse = recordArea.readRecord(reuse); + if (!isAbandoned) { + return reuse; + } + } + } else { + return null; + } + } + + @Override + public T next() throws IOException { + return next(buildSideSerializer.createInstance()); + } + } + + /** + * Returns an iterator that can be used to iterate over all the elements in the table. + * WARNING: Doing any other operation on the table invalidates the iterator! (Even + * using getMatchFor of a prober!) + * @return the iterator + */ + @Override + public EntryIterator getEntryIterator() { + return new EntryIterator(); + } + + /** + * This function reinitializes the bucket segments, + * reads all records from the record segments (sequentially, without using the pointers or the buckets), + * and rebuilds the hash table. + */ + private void rebuild() throws IOException { + rebuild(bucketSegments.length); + } + + /** Same as above, but the number of bucket segments of the new table can be specified. 
*/ + private void rebuild(long newNumBucketSegments) throws IOException { + // Get new bucket segments + releaseBucketSegments(); + allocateBucketSegments((int)newNumBucketSegments); + + T record = buildSideSerializer.createInstance(); + try { + EntryIterator iter = getEntryIterator(); + recordArea.resetAppendPosition(); + recordArea.setWritePosition(0); + while ((record = iter.next(record)) != null) { + final int hashCode = MathUtils.hash(buildSideComparator.hash(record)); + final int bucket = hashCode & numBucketsMask; + final int bucketSegmentIndex = bucket >>> numBucketsPerSegmentBits; // which segment contains the bucket + final MemorySegment bucketSegment = bucketSegments[bucketSegmentIndex]; + final int bucketOffset = (bucket & numBucketsPerSegmentMask) << bucketSizeBits; // offset of the bucket in the segment + final long firstPointer = bucketSegment.getLong(bucketOffset); + + long ptrToAppended = recordArea.noSeekAppendPointerAndRecord(firstPointer, record); + bucketSegment.putLong(bucketOffset, ptrToAppended); + } + recordArea.freeSegmentsAfterAppendPosition(); + holes = 0; + + } catch (EOFException ex) { + throw new RuntimeException("Bug in ReduceHashTable: we shouldn't run out of memory during a rebuild, " + + "because we aren't allocating any new memory."); + } + } + + /** + * Emits all elements currently held by the table to the collector, + * and resets the table. The table will have the same number of buckets + * as before the reset, to avoid doing resizes again. + */ + public void emitAndReset() throws IOException { + final int oldNumBucketSegments = bucketSegments.length; + emit(); + close(); + open(oldNumBucketSegments); + } + + /** + * Emits all elements currently held by the table to the collector. + */ + public void emit() throws IOException { + T record = buildSideSerializer.createInstance(); + EntryIterator iter = getEntryIterator(); + while ((record = iter.next(record)) != null) { + outputCollector.collect(record); + if (!objectReuseEnabled) { + record = buildSideSerializer.createInstance(); + } + } + } + + /** + * If there is wasted space due to updated records not fitting in their old places, then do a compaction. + * Otherwise, throw an EOFException to indicate that memory ran out. + * @throws IOException + */ + private void compactOrThrow() throws IOException { + if (holes > 0) { + rebuild(); + } else { + throw new EOFException("ReduceHashTable memory ran out. " + getMemoryConsumptionString()); + } + } + + /** + * @return String containing a summary of the memory consumption for error messages + */ + private String getMemoryConsumptionString() { + return "ReduceHashTable memory stats:\n" + + "Total memory: " + numAllMemorySegments * segmentSize + "\n" + + "Bucket area: " + numBuckets * 8 + "\n" + + "Record area: " + recordArea.getTotalSize() + "\n" + + "Staging area: " + stagingSegments.size() * segmentSize + "\n" + + "Num of elements: " + numElements + "\n" + + "Holes total size: " + holes; + } + + /** + * This class encapsulates the memory segments that belong to the record area.
It + * - can append a record + * - can overwrite a record at an arbitrary position (WARNING: the new record must have the same size + * as the old one) + * - can be rewritten by calling resetAppendPosition + * - takes memory from ReduceHashTable.freeMemorySegments on append + */ + private final class RecordArea + { + private final ArrayList<MemorySegment> segments = new ArrayList<>(); + + private final OutputView outView; + private final RandomAccessInputView inView; + + private final int segmentSizeBits; + private final int segmentSizeMask; + + private long appendPosition = 0; + + + private final class OutputView extends AbstractPagedOutputView { --- End diff -- Can you use `SimpleCollectingOutputView` instead of an own implementation and let `ReduceHashTable` implement `MemorySegmentSource`?