Github user ggevay commented on a diff in the pull request:

    https://github.com/apache/flink/pull/3511#discussion_r140663639
  
    --- Diff: 
flink-runtime/src/main/java/org/apache/flink/runtime/codegeneration/SorterTemplateModel.java
 ---
    @@ -0,0 +1,424 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *     http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.flink.runtime.codegeneration;
    +
    +import org.apache.flink.api.common.typeutils.TypeComparator;
    +import org.apache.flink.runtime.operators.sort.NormalizedKeySorter;
    +
    +import java.nio.ByteOrder;
    +import java.util.ArrayList;
    +import java.util.HashMap;
    +import java.util.Map;
    +
    +/**
    + * {@link SorterTemplateModel} is a class that implements code generation 
logic for a given
    + * {@link TypeComparator}.
    + *
    + * <p>The swap and compare methods in {@link NormalizedKeySorter} work on 
a sequence of bytes.
    + * We speed up these operations by splitting this sequence of bytes into 
chunks that can
    + * be handled by primitive operations such as Integer and Long 
operations.</p>
    + */
    +class SorterTemplateModel {
    +
    // ------------------------------------------------------------------------
    //                                   Constants
    // ------------------------------------------------------------------------
    +
    /** File name of the sorter template. */
    static final String TEMPLATE_NAME = "sorter.ftlh";

    /** Normalized keys longer than this many bytes are not split into chunks. */
    private static final int SPLITTING_THRESHOLD = 32;

    /**
     * Chunk sizes (in bytes) that can each be handled by a single primitive operation.
     *
     * <p>Must be in descending order, because calculateChunks() picks chunk sizes greedily
     * (largest first), and must end with 1 so that any number of bytes can be covered.
     */
    private static final Integer[] POSSIBLE_CHUNK_SIZES = {8, 4, 2, 1};

    /**
     * Mapping from chunk sizes to the primitive name used in the generated
     * {@code get<Type>} / {@code put<Type>} calls.
     */
    private static final HashMap<Integer, String> byteOperatorMapping = new HashMap<>();

    static {
        // Populated in a static initializer rather than via double-brace initialization,
        // which would define an extra anonymous HashMap subclass just to hold this data.
        byteOperatorMapping.put(8, "Long");
        byteOperatorMapping.put(4, "Int");
        byteOperatorMapping.put(2, "Short");
        byteOperatorMapping.put(1, "Byte");
    }
    +
    // ------------------------------------------------------------------------
    //                                   Attributes
    // ------------------------------------------------------------------------
    +
    /**
     * Comparator of the records to be sorted. The constructor queries it for normalized-key
     * support and length to decide how the index entries are laid out.
     */
    private final TypeComparator typeComparator;

    /**
     * Sizes (in bytes) of the primitive chunks that the index entry (pointer followed by the
     * normalized key) is split into, in memory order. Empty if we are not splitting to
     * chunks (see calculateChunks()).
     */
    private final ArrayList<Integer> primitiveChunks;

    /**
     * Name of the generated sorter, computed from the chunk layout and
     * normalizedKeyFullyDetermines; exposed to the template as "name".
     */
    private final String sorterName;

    /**
     * Whether the order of records can be completely determined by the normalized sorting
     * key, or the sorter also has to deserialize records whose keys are equal to really
     * confirm the order. Always false when the comparator does not support normalized keys.
     */
    private final boolean normalizedKeyFullyDetermines;
    +
    +   /**
    +    * Constructor.
    +    * @param typeComparator
    +    *                The type information of underlying data
    +    */
    +   SorterTemplateModel(TypeComparator typeComparator){
    +           this.typeComparator = typeComparator;
    +
    +           // number of bytes of the sorting key
    +           int numKeyBytes;
    +
    +           // compute no. bytes for sorting records and check whether 
these bytes are just a prefix or not.
    +           if (this.typeComparator.supportsNormalizedKey()) {
    +                   // compute the max normalized key length
    +                   int numPartialKeys;
    +                   try {
    +                           numPartialKeys = 
this.typeComparator.getFlatComparators().length;
    +                   } catch (Throwable t) {
    +                           numPartialKeys = 1;
    +                   }
    +
    +                   int maxLen = 
Math.min(NormalizedKeySorter.DEFAULT_MAX_NORMALIZED_KEY_LEN, 
NormalizedKeySorter.MAX_NORMALIZED_KEY_LEN_PER_ELEMENT * numPartialKeys);
    +
    +                   numKeyBytes = 
Math.min(this.typeComparator.getNormalizeKeyLen(), maxLen);
    +                   this.normalizedKeyFullyDetermines = 
!this.typeComparator.isNormalizedKeyPrefixOnly(numKeyBytes);
    +           }
    +           else {
    +                   numKeyBytes = 0;
    +                   this.normalizedKeyFullyDetermines = false;
    +           }
    +
    +           this.primitiveChunks = calculateChunks(numKeyBytes);
    +
    +           this.sorterName = generateCodeFilename(this.primitiveChunks, 
this.normalizedKeyFullyDetermines);
    +   }
    +
    // ------------------------------------------------------------------------
    //                               Public Methods
    // ------------------------------------------------------------------------
    +
    +   /**
    +    * Generate suitable sequence of operators for creating custom 
NormalizedKeySorter.
    +    * @return map of procedures and corresponding code
    +    */
    +   Map<String, String> getTemplateVariables() {
    +
    +           Map<String, String> templateVariables = new HashMap<>();
    +
    +           templateVariables.put("name", this.sorterName);
    +
    +           String swapProcedures    = generateSwapProcedures();
    +           String writeProcedures   = generateWriteProcedures();
    +           String compareProcedures = generateCompareProcedures();
    +
    +           templateVariables.put("writeProcedures", writeProcedures);
    +           templateVariables.put("swapProcedures", swapProcedures);
    +           templateVariables.put("compareProcedures", compareProcedures);
    +
    +           return templateVariables;
    +   }
    +
    +   /**
    +    * Getter for sorterName (generated in the constructor).
    +    * @return name of the sorter
    +    */
    +   String getSorterName(){
    +           return this.sorterName;
    +   }
    +
    // ------------------------------------------------------------------------
    //                         Methods visible for testing
    // ------------------------------------------------------------------------
    +
    +   /**
    +    * Getter for primitiveChunks.
    +    * this method is for testing purposes
    +    */
    +   ArrayList<Integer> getPrimitiveChunks(){
    +           return primitiveChunks;
    +   }
    +
    // ------------------------------------------------------------------------
    //                               Private Methods
    // ------------------------------------------------------------------------
    +
    +   /**
    +    *  Given no. of bytes, break it into chunks that can be handled by
    +    *  primitive operations (e.g., integer or long operations)
    +    *  @return ArrayList of chunk sizes
    +    */
    +   private ArrayList<Integer> calculateChunks(int numKeyBytes){
    +           ArrayList<Integer> chunks = new ArrayList<>();
    +
    +           // if no. of bytes is too large, we don't split
    +           if (numKeyBytes > SPLITTING_THRESHOLD) {
    +                   return chunks;
    +           }
    +
    +           // also include the offset because of the pointer
    +           numKeyBytes += NormalizedKeySorter.OFFSET_LEN;
    +
    +           // greedy finding of chunk sizes
    +           int i = 0;
    +           while (numKeyBytes > 0) {
    +                   int bytes = POSSIBLE_CHUNK_SIZES[i];
    +                   if (bytes <= numKeyBytes) {
    +                           chunks.add(bytes);
    +                           numKeyBytes -= bytes;
    +                   } else {
    +                           i++;
    +                   }
    +           }
    +
    +           // generateCompareProcedures and generateWriteProcedures skip 
the
    +           // first 8 bytes, because it contains the pointer.
    +           // They do this by skipping the first entry of primitiveChunks, 
because that
    +           // should always be 8 in this case.
    +           if (!(NormalizedKeySorter.OFFSET_LEN == 8 && 
chunks.get(0).equals(8))) {
    +                   throw new RuntimeException("Bug: Incorrect OFFSET_LEN 
or primitiveChunks");
    +           }
    +
    +           return chunks;
    +   }
    +
    +   /**
    +    * Based on primitiveChunks variable, generate the most suitable 
operators
    +    * for swapping function.
    +    *
    +    * @return code used in the swap method
    +    */
    +   private String generateSwapProcedures(){
    +           /* Example generated code, for 20 bytes (8+8+4):
    +
    +           long temp1 = segI.getLong(segmentOffsetI);
    +           long temp2 = segI.getLong(segmentOffsetI+8);
    +           int temp3 = segI.getInt(segmentOffsetI+16);
    +
    +           segI.putLong(segmentOffsetI, segJ.getLong(segmentOffsetJ));
    +           segI.putLong(segmentOffsetI+8, segJ.getLong(segmentOffsetJ+8));
    +           segI.putInt(segmentOffsetI+16, segJ.getInt(segmentOffsetJ+16));
    +
    +           segJ.putLong(segmentOffsetJ, temp1);
    +           segJ.putLong(segmentOffsetJ+8, temp2);
    +           segJ.putInt(segmentOffsetJ+16, temp3);
    +            */
    +
    +           String procedures = "";
    +
    +           if (this.primitiveChunks.size() > 0) {
    +                   StringBuilder temporaryString     = new StringBuilder();
    +                   StringBuilder firstSegmentString  = new StringBuilder();
    +                   StringBuilder secondSegmentString = new StringBuilder();
    +
    +                   int accOffset = 0;
    +                   for (int i = 0; i  < primitiveChunks.size(); i++){
    +                           int numberByte = primitiveChunks.get(i);
    +                           int varIndex = i + 1;
    +
    +                           String primitiveClass = 
byteOperatorMapping.get(numberByte);
    +                           String primitiveType  = 
primitiveClass.toLowerCase();
    +
    +                           String offsetString = "";
    +                           if (i > 0) {
    +                                   accOffset += primitiveChunks.get(i - 1);
    +                                   offsetString = "+" + accOffset;
    +                           }
    +
    +                           temporaryString.append(String.format("%s temp%d 
= segI.get%s(segmentOffsetI%s);\n",
    +                                   primitiveType, varIndex, 
primitiveClass, offsetString));
    +
    +                           
firstSegmentString.append(String.format("segI.put%s(segmentOffsetI%s, 
segJ.get%s(segmentOffsetJ%s));\n",
    +                                   primitiveClass, offsetString, 
primitiveClass, offsetString));
    +
    +                           
secondSegmentString.append(String.format("segJ.put%s(segmentOffsetJ%s, 
temp%d);\n",
    +                                   primitiveClass, offsetString, 
varIndex));
    +
    +                   }
    +
    +                   procedures = temporaryString.toString()
    +                           + "\n" + firstSegmentString.toString()
    +                           + "\n" + secondSegmentString.toString();
    +           } else {
    +                   procedures = "segI.swapBytes(this.swapBuffer, segJ, 
segmentOffsetI, segmentOffsetJ, this.indexEntrySize);";
    +           }
    +
    +           return procedures;
    +   }
    +
    +   /**
    +    * Based on primitiveChunks variable, generate reverse byte operators 
for little endian machine
    +    * for writing a record to MemorySegment, such that later during 
comparison
    +    * we can directly use native byte order to do unsigned comparison.
    +    *
    +    * @return code used in the write method
    +    */
    +   private String generateWriteProcedures(){
    +           /* Example generated code, for 12 bytes (8+4):
    +
    +           long temp1 = 
Long.reverseBytes(this.currentSortIndexSegment.getLong(this.currentSortIndexOffset+8));
    +           
this.currentSortIndexSegment.putLong(this.currentSortIndexOffset + 8, temp1);
    +           int temp2 = 
Integer.reverseBytes(this.currentSortIndexSegment.getInt(this.currentSortIndexOffset+16));
    +           this.currentSortIndexSegment.putInt(this.currentSortIndexOffset 
+ 16, temp2);
    +            */
    +
    +           StringBuilder procedures = new StringBuilder();
    +           // skip the first chunk, which is the pointer before the key
    +           if (primitiveChunks.size() > 1 && ByteOrder.nativeOrder() == 
ByteOrder.LITTLE_ENDIAN) {
    +                   int offset = 0;
    +                   // starts from 1 because of skipping the first chunk
    +                   for (int i = 1; i < primitiveChunks.size(); i++){
    +                           int noBytes = primitiveChunks.get(i);
    +                           if (noBytes == 1){
    --- End diff --
    
    41fd844173c5e6c8f8935df5b6a7762ae45596b6


---

Reply via email to