Github user ggevay commented on a diff in the pull request: https://github.com/apache/flink/pull/3511#discussion_r140663639 --- Diff: flink-runtime/src/main/java/org/apache/flink/runtime/codegeneration/SorterTemplateModel.java --- @@ -0,0 +1,424 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.codegeneration; + +import org.apache.flink.api.common.typeutils.TypeComparator; +import org.apache.flink.runtime.operators.sort.NormalizedKeySorter; + +import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Map; + +/** + * {@link SorterTemplateModel} is a class that implements code generation logic for a given + * {@link TypeComparator}. + * + * <p>The swap and compare methods in {@link NormalizedKeySorter} work on a sequence of bytes. + * We speed up these operations by splitting this sequence of bytes into chunks that can + * be handled by primitive operations such as Integer and Long operations.</p> + */ +class SorterTemplateModel { + + // ------------------------------------------------------------------------ + // Constants + // ------------------------------------------------------------------------ + + static final String TEMPLATE_NAME = "sorter.ftlh"; + + /** We don't split to chunks above this size. */ + private static final int SPLITTING_THRESHOLD = 32; + + /** + * POSSIBLE_CHUNK_SIZES must be in descending order, + * because methods that using it are using greedy approach. + */ + private static final Integer[] POSSIBLE_CHUNK_SIZES = {8, 4, 2, 1}; + + /** Mapping from chunk sizes to primitive operators. */ + private static final HashMap<Integer, String> byteOperatorMapping = new HashMap<Integer, String>(){ + { + put(8, "Long"); + put(4, "Int"); + put(2, "Short"); + put(1, "Byte"); + } + }; + + // ------------------------------------------------------------------------ + // Attributes + // ------------------------------------------------------------------------ + + private final TypeComparator typeComparator; + + /** + * Sizes of the chunks. Empty, if we are not splitting to chunks. (See calculateChunks()) + */ + private final ArrayList<Integer> primitiveChunks; + + private final String sorterName; + + /** + * Shows whether the order of records can be completely determined by the normalized + * sorting key, or the sorter has to also deserialize records if their keys are equal to + * really confirm the order. + */ + private final boolean normalizedKeyFullyDetermines; + + /** + * Constructor. + * @param typeComparator + * The type information of underlying data + */ + SorterTemplateModel(TypeComparator typeComparator){ + this.typeComparator = typeComparator; + + // number of bytes of the sorting key + int numKeyBytes; + + // compute no. bytes for sorting records and check whether these bytes are just a prefix or not. + if (this.typeComparator.supportsNormalizedKey()) { + // compute the max normalized key length + int numPartialKeys; + try { + numPartialKeys = this.typeComparator.getFlatComparators().length; + } catch (Throwable t) { + numPartialKeys = 1; + } + + int maxLen = Math.min(NormalizedKeySorter.DEFAULT_MAX_NORMALIZED_KEY_LEN, NormalizedKeySorter.MAX_NORMALIZED_KEY_LEN_PER_ELEMENT * numPartialKeys); + + numKeyBytes = Math.min(this.typeComparator.getNormalizeKeyLen(), maxLen); + this.normalizedKeyFullyDetermines = !this.typeComparator.isNormalizedKeyPrefixOnly(numKeyBytes); + } + else { + numKeyBytes = 0; + this.normalizedKeyFullyDetermines = false; + } + + this.primitiveChunks = calculateChunks(numKeyBytes); + + this.sorterName = generateCodeFilename(this.primitiveChunks, this.normalizedKeyFullyDetermines); + } + + // ------------------------------------------------------------------------ + // Public Methods + // ------------------------------------------------------------------------ + + /** + * Generate suitable sequence of operators for creating custom NormalizedKeySorter. + * @return map of procedures and corresponding code + */ + Map<String, String> getTemplateVariables() { + + Map<String, String> templateVariables = new HashMap<>(); + + templateVariables.put("name", this.sorterName); + + String swapProcedures = generateSwapProcedures(); + String writeProcedures = generateWriteProcedures(); + String compareProcedures = generateCompareProcedures(); + + templateVariables.put("writeProcedures", writeProcedures); + templateVariables.put("swapProcedures", swapProcedures); + templateVariables.put("compareProcedures", compareProcedures); + + return templateVariables; + } + + /** + * Getter for sorterName (generated in the constructor). + * @return name of the sorter + */ + String getSorterName(){ + return this.sorterName; + } + + // ------------------------------------------------------------------------ + // Protected Methods + // ------------------------------------------------------------------------ + + /** + * Getter for primitiveChunks. + * this method is for testing purposes + */ + ArrayList<Integer> getPrimitiveChunks(){ + return primitiveChunks; + } + + // ------------------------------------------------------------------------ + // Private Methods + // ------------------------------------------------------------------------ + + /** + * Given no. of bytes, break it into chunks that can be handled by + * primitive operations (e.g., integer or long operations) + * @return ArrayList of chunk sizes + */ + private ArrayList<Integer> calculateChunks(int numKeyBytes){ + ArrayList<Integer> chunks = new ArrayList<>(); + + // if no. of bytes is too large, we don't split + if (numKeyBytes > SPLITTING_THRESHOLD) { + return chunks; + } + + // also include the offset because of the pointer + numKeyBytes += NormalizedKeySorter.OFFSET_LEN; + + // greedy finding of chunk sizes + int i = 0; + while (numKeyBytes > 0) { + int bytes = POSSIBLE_CHUNK_SIZES[i]; + if (bytes <= numKeyBytes) { + chunks.add(bytes); + numKeyBytes -= bytes; + } else { + i++; + } + } + + // generateCompareProcedures and generateWriteProcedures skip the + // first 8 bytes, because it contains the pointer. + // They do this by skipping the first entry of primitiveChunks, because that + // should always be 8 in this case. + if (!(NormalizedKeySorter.OFFSET_LEN == 8 && chunks.get(0).equals(8))) { + throw new RuntimeException("Bug: Incorrect OFFSET_LEN or primitiveChunks"); + } + + return chunks; + } + + /** + * Based on primitiveChunks variable, generate the most suitable operators + * for swapping function. + * + * @return code used in the swap method + */ + private String generateSwapProcedures(){ + /* Example generated code, for 20 bytes (8+8+4): + + long temp1 = segI.getLong(segmentOffsetI); + long temp2 = segI.getLong(segmentOffsetI+8); + int temp3 = segI.getInt(segmentOffsetI+16); + + segI.putLong(segmentOffsetI, segJ.getLong(segmentOffsetJ)); + segI.putLong(segmentOffsetI+8, segJ.getLong(segmentOffsetJ+8)); + segI.putInt(segmentOffsetI+16, segJ.getInt(segmentOffsetJ+16)); + + segJ.putLong(segmentOffsetJ, temp1); + segJ.putLong(segmentOffsetJ+8, temp2); + segJ.putInt(segmentOffsetJ+16, temp3); + */ + + String procedures = ""; + + if (this.primitiveChunks.size() > 0) { + StringBuilder temporaryString = new StringBuilder(); + StringBuilder firstSegmentString = new StringBuilder(); + StringBuilder secondSegmentString = new StringBuilder(); + + int accOffset = 0; + for (int i = 0; i < primitiveChunks.size(); i++){ + int numberByte = primitiveChunks.get(i); + int varIndex = i + 1; + + String primitiveClass = byteOperatorMapping.get(numberByte); + String primitiveType = primitiveClass.toLowerCase(); + + String offsetString = ""; + if (i > 0) { + accOffset += primitiveChunks.get(i - 1); + offsetString = "+" + accOffset; + } + + temporaryString.append(String.format("%s temp%d = segI.get%s(segmentOffsetI%s);\n", + primitiveType, varIndex, primitiveClass, offsetString)); + + firstSegmentString.append(String.format("segI.put%s(segmentOffsetI%s, segJ.get%s(segmentOffsetJ%s));\n", + primitiveClass, offsetString, primitiveClass, offsetString)); + + secondSegmentString.append(String.format("segJ.put%s(segmentOffsetJ%s, temp%d);\n", + primitiveClass, offsetString, varIndex)); + + } + + procedures = temporaryString.toString() + + "\n" + firstSegmentString.toString() + + "\n" + secondSegmentString.toString(); + } else { + procedures = "segI.swapBytes(this.swapBuffer, segJ, segmentOffsetI, segmentOffsetJ, this.indexEntrySize);"; + } + + return procedures; + } + + /** + * Based on primitiveChunks variable, generate reverse byte operators for little endian machine + * for writing a record to MemorySegment, such that later during comparison + * we can directly use native byte order to do unsigned comparison. + * + * @return code used in the write method + */ + private String generateWriteProcedures(){ + /* Example generated code, for 12 bytes (8+4): + + long temp1 = Long.reverseBytes(this.currentSortIndexSegment.getLong(this.currentSortIndexOffset+8)); + this.currentSortIndexSegment.putLong(this.currentSortIndexOffset + 8, temp1); + int temp2 = Integer.reverseBytes(this.currentSortIndexSegment.getInt(this.currentSortIndexOffset+16)); + this.currentSortIndexSegment.putInt(this.currentSortIndexOffset + 16, temp2); + */ + + StringBuilder procedures = new StringBuilder(); + // skip the first chunk, which is the pointer before the key + if (primitiveChunks.size() > 1 && ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN) { + int offset = 0; + // starts from 1 because of skipping the first chunk + for (int i = 1; i < primitiveChunks.size(); i++){ + int noBytes = primitiveChunks.get(i); + if (noBytes == 1){ --- End diff -- 41fd844173c5e6c8f8935df5b6a7762ae45596b6
---