Repository: flink Updated Branches: refs/heads/master 499b60fed -> 440137cc3
[FLINK-3140] [table] NULL value data layout in Row Serializer/Comparator This closes #1465. Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/440137cc Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/440137cc Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/440137cc Branch: refs/heads/master Commit: 440137cc3ad8ab47dff84212c81e960e0a254eaa Parents: 499b60f Author: twalthr <twal...@apache.org> Authored: Fri Jan 22 15:28:03 2016 +0100 Committer: twalthr <twal...@apache.org> Committed: Thu Jan 28 11:44:32 2016 +0100 ---------------------------------------------------------------------- .../common/typeutils/ComparatorTestBase.java | 30 ++ .../runtime/TupleComparatorTTT1Test.java | 18 +- .../runtime/TupleComparatorTTT2Test.java | 8 - flink-libraries/flink-table/pom.xml | 8 + .../table/typeinfo/NullAwareComparator.scala | 218 ++++++++++ .../api/table/typeinfo/NullMaskUtils.scala | 98 +++++ .../api/table/typeinfo/RowComparator.scala | 417 +++++++++++++++++++ .../api/table/typeinfo/RowSerializer.scala | 114 ++++- .../flink/api/table/typeinfo/RowTypeInfo.scala | 61 ++- .../table/test/GroupedAggreagationsITCase.scala | 115 ----- .../table/test/GroupedAggregationsITCase.scala | 138 ++++++ .../api/table/typeinfo/RowComparatorTest.scala | 135 ++++++ .../api/table/typeinfo/RowSerializerTest.scala | 208 +++++++++ .../api/scala/typeutils/CaseClassTypeInfo.scala | 5 +- 14 files changed, 1408 insertions(+), 165 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/440137cc/flink-core/src/test/java/org/apache/flink/api/common/typeutils/ComparatorTestBase.java ---------------------------------------------------------------------- diff --git a/flink-core/src/test/java/org/apache/flink/api/common/typeutils/ComparatorTestBase.java b/flink-core/src/test/java/org/apache/flink/api/common/typeutils/ComparatorTestBase.java index 
a8ace92..793688d 100644 --- a/flink-core/src/test/java/org/apache/flink/api/common/typeutils/ComparatorTestBase.java +++ b/flink-core/src/test/java/org/apache/flink/api/common/typeutils/ComparatorTestBase.java @@ -399,6 +399,32 @@ public abstract class ComparatorTestBase<T> extends TestLogger { } } + // -------------------------------- Key extraction tests -------------------------------------- + + @Test + @SuppressWarnings("unchecked") + public void testKeyExtraction() { + TypeComparator<T> comparator = getComparator(true); + T[] data = getSortedData(); + + for (T value : data) { + TypeComparator[] comparators = comparator.getFlatComparators(); + Object[] extractedKeys = new Object[comparators.length]; + int insertedKeys = comparator.extractKeys(value, extractedKeys, 0); + assertTrue(insertedKeys == comparators.length); + + for (int i = 0; i < insertedKeys; i++) { + // check if some keys are null, although this is not supported + if (!supportsNullKeys()) { + assertNotNull(extractedKeys[i]); + } + // compare the extracted key with itself as a basic check + // if the extracted key corresponds to the comparator + assertTrue(comparators[i].compare(extractedKeys[i], extractedKeys[i]) == 0); + } + } + } + // -------------------------------------------------------------------------------------------- protected void deepEquals(String message, T should, T is) { @@ -450,6 +476,10 @@ public abstract class ComparatorTestBase<T> extends TestLogger { } + protected boolean supportsNullKeys() { + return false; + } + // -------------------------------------------------------------------------------------------- public static final class TestOutputView extends DataOutputStream implements DataOutputView { http://git-wip-us.apache.org/repos/asf/flink/blob/440137cc/flink-java/src/test/java/org/apache/flink/api/java/typeutils/runtime/TupleComparatorTTT1Test.java ---------------------------------------------------------------------- diff --git 
a/flink-java/src/test/java/org/apache/flink/api/java/typeutils/runtime/TupleComparatorTTT1Test.java b/flink-java/src/test/java/org/apache/flink/api/java/typeutils/runtime/TupleComparatorTTT1Test.java index 6d4afdd..cf73be2 100644 --- a/flink-java/src/test/java/org/apache/flink/api/java/typeutils/runtime/TupleComparatorTTT1Test.java +++ b/flink-java/src/test/java/org/apache/flink/api/java/typeutils/runtime/TupleComparatorTTT1Test.java @@ -63,23 +63,7 @@ public class TupleComparatorTTT1Test extends TupleComparatorTestBase<Tuple3<Tupl new DoubleComparator(ascending) }, new TypeSerializer[] { StringSerializer.INSTANCE, - DoubleSerializer.INSTANCE }), - new TupleComparator<Tuple2<Long, Long>>( - new int[] { 0, 1 }, - new TypeComparator[] { - new LongComparator(ascending), - new LongComparator(ascending) }, - new TypeSerializer[] { - LongSerializer.INSTANCE, - LongSerializer.INSTANCE }), - new TupleComparator<Tuple2<Integer, Long>>( - new int[] { 0, 1 }, - new TypeComparator[] { - new IntComparator(ascending), - new LongComparator(ascending) }, - new TypeSerializer[] { - IntSerializer.INSTANCE, - LongSerializer.INSTANCE }) }, + DoubleSerializer.INSTANCE }) }, new TypeSerializer[] { new TupleSerializer<Tuple2<String, Double>>( (Class<Tuple2<String, Double>>) (Class<?>) Tuple2.class, http://git-wip-us.apache.org/repos/asf/flink/blob/440137cc/flink-java/src/test/java/org/apache/flink/api/java/typeutils/runtime/TupleComparatorTTT2Test.java ---------------------------------------------------------------------- diff --git a/flink-java/src/test/java/org/apache/flink/api/java/typeutils/runtime/TupleComparatorTTT2Test.java b/flink-java/src/test/java/org/apache/flink/api/java/typeutils/runtime/TupleComparatorTTT2Test.java index 56ceb59..4b07c61 100644 --- a/flink-java/src/test/java/org/apache/flink/api/java/typeutils/runtime/TupleComparatorTTT2Test.java +++ b/flink-java/src/test/java/org/apache/flink/api/java/typeutils/runtime/TupleComparatorTTT2Test.java @@ -65,14 +65,6 @@ public 
class TupleComparatorTTT2Test extends TupleComparatorTestBase<Tuple3<Tupl new TypeSerializer[] { StringSerializer.INSTANCE, DoubleSerializer.INSTANCE }), - new TupleComparator<Tuple2<Long, Long>>( - new int[] { 0, 1 }, - new TypeComparator[] { - new LongComparator(ascending), - new LongComparator(ascending) }, - new TypeSerializer[] { - LongSerializer.INSTANCE, - LongSerializer.INSTANCE }), new TupleComparator<Tuple2<Integer, Long>>( new int[] { 0, 1 }, new TypeComparator[] { http://git-wip-us.apache.org/repos/asf/flink/blob/440137cc/flink-libraries/flink-table/pom.xml ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/pom.xml b/flink-libraries/flink-table/pom.xml index a486525..73f77f4 100644 --- a/flink-libraries/flink-table/pom.xml +++ b/flink-libraries/flink-table/pom.xml @@ -100,6 +100,14 @@ under the License. <version>${jackson.version}</version> </dependency> + <dependency> + <groupId>org.apache.flink</groupId> + <artifactId>flink-core</artifactId> + <version>${project.version}</version> + <type>test-jar</type> + <scope>test</scope> + </dependency> + </dependencies> <build> http://git-wip-us.apache.org/repos/asf/flink/blob/440137cc/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeinfo/NullAwareComparator.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeinfo/NullAwareComparator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeinfo/NullAwareComparator.scala new file mode 100644 index 0000000..ebbd471 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeinfo/NullAwareComparator.scala @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.api.table.typeinfo + +import org.apache.flink.api.common.typeutils.{CompositeTypeComparator, TypeComparator} +import org.apache.flink.core.memory.{DataInputView, DataOutputView, MemorySegment} + +import scala.collection.JavaConversions._ +import scala.collection.mutable.ArrayBuffer + +/** + * Null-aware comparator that wraps a comparator which does not support null references. + * + * NOTE: This class is intended to be used within a composite type comparator (such + * as [[RowComparator]]) that handles serialized comparison. 
+ */ +class NullAwareComparator[T]( + val wrappedComparator: TypeComparator[T], + val order: Boolean) + extends TypeComparator[T] { + + // number of flat fields + private val flatFields = wrappedComparator.getFlatComparators.length + + // stores the null for reference comparison + private var nullReference = false + + override def hash(record: T): Int = { + if (record != null) { + wrappedComparator.hash(record) + } + else { + 0 + } + } + + override def getNormalizeKeyLen: Int = { + val len = wrappedComparator.getNormalizeKeyLen + if (len == Integer.MAX_VALUE) { + Integer.MAX_VALUE + } + else { + len + 1 // add one for a null byte + } + } + + override def putNormalizedKey( + record: T, + target: MemorySegment, + offset: Int, + numBytes: Int) + : Unit = { + if (numBytes > 0) { + // write a null byte with padding + if (record == null) { + target.putBoolean(offset, false) + // write padding + var j = 0 + while (j < numBytes - 1) { + target.put(offset + 1 + j, 0.toByte) + j += 1 + } + } + // write a non-null byte with key + else { + target.putBoolean(offset, true) + // write key + wrappedComparator.putNormalizedKey(record, target, offset + 1, numBytes - 1) + } + } + } + + override def invertNormalizedKey(): Boolean = wrappedComparator.invertNormalizedKey() + + override def supportsSerializationWithKeyNormalization(): Boolean = false + + override def writeWithKeyNormalization(record: T, target: DataOutputView): Unit = + throw new UnsupportedOperationException("Record serialization with leading normalized keys" + + " not supported.") + + override def readWithKeyDenormalization(reuse: T, source: DataInputView): T = + throw new UnsupportedOperationException("Record deserialization with leading normalized keys" + + " not supported.") + + override def isNormalizedKeyPrefixOnly(keyBytes: Int): Boolean = + wrappedComparator.isNormalizedKeyPrefixOnly(keyBytes - 1) + + override def setReference(toCompare: T): Unit = { + if (toCompare == null) { + nullReference = true + } + else { 
+ nullReference = false + wrappedComparator.setReference(toCompare) + } + } + + override def compare(first: T, second: T): Int = { + // both values are null -> equality + if (first == null && second == null) { + 0 + } + // first value is null -> inequality + // but order is considered + else if (first == null) { + if (order) -1 else 1 + } + // second value is null -> inequality + // but order is considered + else if (second == null) { + if (order) 1 else -1 + } + // no null values + else { + wrappedComparator.compare(first, second) + } + } + + override def compareToReference(referencedComparator: TypeComparator[T]): Int = { + val otherComparator = referencedComparator.asInstanceOf[NullAwareComparator[T]] + val otherNullReference = otherComparator.nullReference + // both values are null -> equality + if (nullReference && otherNullReference) { + 0 + } + // first value is null -> inequality + // but order is considered + else if (nullReference) { + if (order) 1 else -1 + } + // second value is null -> inequality + // but order is considered + else if (otherNullReference) { + if (order) -1 else 1 + } + // no null values + else { + wrappedComparator.compareToReference(otherComparator.wrappedComparator) + } + } + + override def supportsNormalizedKey(): Boolean = wrappedComparator.supportsNormalizedKey() + + override def equalToReference(candidate: T): Boolean = { + // both values are null + if (candidate == null && nullReference) { + true + } + // one value is null + else if (candidate == null || nullReference) { + false + } + // no null value + else { + wrappedComparator.equalToReference(candidate) + } + } + + override def duplicate(): TypeComparator[T] = { + new NullAwareComparator[T](wrappedComparator.duplicate(), order) + } + + override def extractKeys(record: Any, target: Array[AnyRef], index: Int): Int = { + if (record == null) { + var i = 0 + while (i < flatFields) { + target(index + i) = null + i += 1 + } + flatFields + } + else { + 
wrappedComparator.extractKeys(record, target, index) + } + } + + + override def getFlatComparators: Array[TypeComparator[_]] = { + // determine the flat comparators and wrap them again in null-aware comparators + val flatComparators = new ArrayBuffer[TypeComparator[_]]() + wrappedComparator match { + case ctc: CompositeTypeComparator[_] => ctc.getFlatComparator(flatComparators) + case c: TypeComparator[_] => flatComparators += c + } + val wrappedComparators = flatComparators.map { c => + new NullAwareComparator[Any](c.asInstanceOf[TypeComparator[Any]], order) + } + wrappedComparators.toArray[TypeComparator[_]] + } + + /** + * This method is not implemented here. It must be implemented by the comparator this class + * is contained in (e.g. RowComparator). + * + * @param firstSource The input view containing the first record. + * @param secondSource The input view containing the second record. + * @return An integer defining the order among the objects in the same way as + * {@link java.util.Comparator#compare(Object, Object)}. + */ + override def compareSerialized(firstSource: DataInputView, secondSource: DataInputView): Int = + throw new UnsupportedOperationException("Comparator does not support null-aware serialized " + + "comparision.") +} http://git-wip-us.apache.org/repos/asf/flink/blob/440137cc/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeinfo/NullMaskUtils.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeinfo/NullMaskUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeinfo/NullMaskUtils.scala new file mode 100644 index 0000000..40a39dd --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeinfo/NullMaskUtils.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.api.table.typeinfo + +import org.apache.flink.api.table.Row +import org.apache.flink.core.memory.{DataInputView, DataOutputView} + +object NullMaskUtils { + + def writeNullMask(len: Int, value: Row, target: DataOutputView): Unit = { + var b = 0x00 + var bytePos = 0 + + var fieldPos = 0 + var numPos = 0 + while (fieldPos < len) { + b = 0x00 + // set bits in byte + bytePos = 0 + numPos = Math.min(8, len - fieldPos) + while (bytePos < numPos) { + b = b << 1 + // set bit if field is null + if(value.productElement(fieldPos + bytePos) == null) { + b |= 0x01 + } + bytePos += 1 + } + fieldPos += numPos + // shift bits if last byte is not completely filled + b <<= (8 - bytePos) + // write byte + target.writeByte(b) + } + } + + def readIntoNullMask(len: Int, source: DataInputView, nullMask: Array[Boolean]): Unit = { + var b = 0x00 + var bytePos = 0 + + var fieldPos = 0 + var numPos = 0 + while (fieldPos < len) { + // read byte + b = source.readUnsignedByte() + bytePos = 0 + numPos = Math.min(8, len - fieldPos) + while (bytePos < numPos) { + nullMask(fieldPos + bytePos) = (b & 0x80) > 0 + b = b << 1 + bytePos += 1 + } + fieldPos += numPos + } + } + + def readIntoAndCopyNullMask( + len: Int, + source: DataInputView, + target: DataOutputView, + nullMask: 
Array[Boolean]): Unit = { + var b = 0x00 + var bytePos = 0 + + var fieldPos = 0 + var numPos = 0 + while (fieldPos < len) { + // read byte + b = source.readUnsignedByte() + // copy byte + target.writeByte(b) + bytePos = 0 + numPos = Math.min(8, len - fieldPos) + while (bytePos < numPos) { + nullMask(fieldPos + bytePos) = (b & 0x80) > 0 + b = b << 1 + bytePos += 1 + } + fieldPos += numPos + } + } + +} http://git-wip-us.apache.org/repos/asf/flink/blob/440137cc/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeinfo/RowComparator.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeinfo/RowComparator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeinfo/RowComparator.scala new file mode 100644 index 0000000..17c6d56 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeinfo/RowComparator.scala @@ -0,0 +1,417 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.api.table.typeinfo + +import java.util + +import org.apache.flink.api.common.typeutils.{CompositeTypeComparator, TypeComparator, TypeSerializer} +import org.apache.flink.api.java.typeutils.runtime.TupleComparatorBase +import org.apache.flink.api.table.Row +import org.apache.flink.api.table.typeinfo.NullMaskUtils.readIntoNullMask +import org.apache.flink.api.table.typeinfo.RowComparator.{createAuxiliaryFields, makeNullAware} +import org.apache.flink.core.memory.{DataInputView, DataOutputView, MemorySegment} +import org.apache.flink.types.KeyFieldOutOfBoundsException + +/** + * Comparator for [[Row]]. + */ +class RowComparator private ( + /** key positions describe which fields are keys in what order */ + val keyPositions: Array[Int], + /** null-aware comparators for the key fields, in the same order as the key fields */ + val comparators: Array[NullAwareComparator[Any]], + /** serializers to deserialize the first n fields for comparison */ + val serializers: Array[TypeSerializer[Any]], + /** auxiliary fields for normalized key support */ + private val auxiliaryFields: (Array[Int], Int, Int, Boolean)) + extends CompositeTypeComparator[Row] with Serializable { + + // null masks for serialized comparison + private val nullMask1 = new Array[Boolean](serializers.length) + private val nullMask2 = new Array[Boolean](serializers.length) + + // cache for the deserialized key field objects + @transient + private lazy val deserializedKeyFields1: Array[Any] = instantiateDeserializationFields() + + @transient + private lazy val deserializedKeyFields2: Array[Any] = instantiateDeserializationFields() + + // create auxiliary fields + private val normalizedKeyLengths: Array[Int] = auxiliaryFields._1 + private val numLeadingNormalizableKeys: Int = auxiliaryFields._2 + private val normalizableKeyPrefixLen: Int = auxiliaryFields._3 + private val invertNormKey: Boolean = auxiliaryFields._4 + + /** + * Intermediate constructor for creating auxiliary 
fields. + */ + def this( + keyPositions: Array[Int], + comparators: Array[NullAwareComparator[Any]], + serializers: Array[TypeSerializer[Any]]) = { + this( + keyPositions, + comparators, + serializers, + createAuxiliaryFields(keyPositions, comparators)) + } + + /** + * General constructor for RowComparator. + * + * @param keyPositions key positions describe which fields are keys in what order + * @param comparators non-null-aware comparators for the key fields, in the same order as + * the key fields + * @param serializers serializers to deserialize the first n fields for comparison + * @param orders sorting orders for the fields + */ + def this( + keyPositions: Array[Int], + comparators: Array[TypeComparator[Any]], + serializers: Array[TypeSerializer[Any]], + orders: Array[Boolean]) = { + this( + keyPositions, + makeNullAware(comparators, orders), + serializers) + } + + private def instantiateDeserializationFields(): Array[Any] = { + val newFields = new Array[Any](serializers.length) + var i = 0 + while (i < serializers.length) { + newFields(i) = serializers(i).createInstance() + i += 1 + } + newFields + } + + // -------------------------------------------------------------------------------------------- + // Comparator Methods + // -------------------------------------------------------------------------------------------- + + override def compareToReference(referencedComparator: TypeComparator[Row]): Int = { + val other: RowComparator = referencedComparator.asInstanceOf[RowComparator] + var i = 0 + try { + while (i < keyPositions.length) { + val comparator = comparators(i) + val otherComparator = other.comparators(i) + + val cmp = comparator.compareToReference(otherComparator) + if (cmp != 0) { + return cmp + } + i = i + 1 + } + 0 + } + catch { + case iobex: IndexOutOfBoundsException => + throw new KeyFieldOutOfBoundsException(keyPositions(i)) + } + } + + override def compareSerialized(firstSource: DataInputView, secondSource: DataInputView): Int = { + val len = 
serializers.length + val keyLen = keyPositions.length + + readIntoNullMask(len, firstSource, nullMask1) + readIntoNullMask(len, secondSource, nullMask2) + + // deserialize + var i = 0 + while (i < len) { + val serializer = serializers(i) + + // deserialize field 1 + if (!nullMask1(i)) { + deserializedKeyFields1(i) = serializer.deserialize(deserializedKeyFields1(i), firstSource) + } + + // deserialize field 2 + if (!nullMask2(i)) { + deserializedKeyFields2(i) = serializer.deserialize(deserializedKeyFields2(i), secondSource) + } + + i += 1 + } + + // compare + i = 0 + while (i < keyLen) { + val keyPos = keyPositions(i) + val comparator = comparators(i) + + val isNull1 = nullMask1(keyPos) + val isNull2 = nullMask2(keyPos) + + var cmp = 0 + // both values are null -> equality + if (isNull1 && isNull2) { + cmp = 0 + } + // first value is null -> inequality + else if (isNull1) { + cmp = comparator.compare(null, deserializedKeyFields2(keyPos)) + } + // second value is null -> inequality + else if (isNull2) { + cmp = comparator.compare(deserializedKeyFields1(keyPos), null) + } + // no null values + else { + cmp = comparator.compare(deserializedKeyFields1(keyPos), deserializedKeyFields2(keyPos)) + } + + if (cmp != 0) { + return cmp + } + + i += 1 + } + 0 + } + + override def supportsNormalizedKey(): Boolean = numLeadingNormalizableKeys > 0 + + override def getNormalizeKeyLen: Int = normalizableKeyPrefixLen + + override def isNormalizedKeyPrefixOnly(keyBytes: Int): Boolean = + numLeadingNormalizableKeys < keyPositions.length || + normalizableKeyPrefixLen == Integer.MAX_VALUE || + normalizableKeyPrefixLen > keyBytes + + override def invertNormalizedKey(): Boolean = invertNormKey + + override def supportsSerializationWithKeyNormalization(): Boolean = false + + override def writeWithKeyNormalization(record: Row, target: DataOutputView): Unit = + throw new UnsupportedOperationException("Record serialization with leading normalized keys " + + "not supported.") + + override def 
readWithKeyDenormalization(reuse: Row, source: DataInputView): Row = + throw new UnsupportedOperationException("Record deserialization with leading normalized keys " + + "not supported.") + + override def duplicate(): TypeComparator[Row] = { + // copy comparator and serializer factories + val comparatorsCopy = comparators.map(_.duplicate().asInstanceOf[NullAwareComparator[Any]]) + val serializersCopy = serializers.map(_.duplicate()) + + new RowComparator( + keyPositions, + comparatorsCopy, + serializersCopy, + auxiliaryFields) + } + + override def hash(value: Row): Int = { + var code: Int = 0 + var i = 0 + try { + while(i < keyPositions.length) { + code *= TupleComparatorBase.HASH_SALT(i & 0x1F) + val element = value.productElement(keyPositions(i)) // element can be null + code += comparators(i).hash(element) + i += 1 + } + } catch { + case iobex: IndexOutOfBoundsException => + throw new KeyFieldOutOfBoundsException(keyPositions(i)) + } + code + } + + override def setReference(toCompare: Row) { + var i = 0 + try { + while(i < keyPositions.length) { + val comparator = comparators(i) + val element = toCompare.productElement(keyPositions(i)) + comparator.setReference(element) // element can be null + i += 1 + } + } catch { + case iobex: IndexOutOfBoundsException => + throw new KeyFieldOutOfBoundsException(keyPositions(i)) + } + } + + override def equalToReference(candidate: Row): Boolean = { + var i = 0 + try { + while(i < keyPositions.length) { + val comparator = comparators(i) + val element = candidate.productElement(keyPositions(i)) // element can be null + // check if reference is not equal + if (!comparator.equalToReference(element)) { + return false + } + i += 1 + } + } catch { + case iobex: IndexOutOfBoundsException => + throw new KeyFieldOutOfBoundsException(keyPositions(i)) + } + true + } + + override def compare(first: Row, second: Row): Int = { + var i = 0 + try { + while(i < keyPositions.length) { + val keyPos: Int = keyPositions(i) + val comparator = 
comparators(i) + val firstElement = first.productElement(keyPos) // element can be null + val secondElement = second.productElement(keyPos) // element can be null + + val cmp = comparator.compare(firstElement, secondElement) + if (cmp != 0) { + return cmp + } + i += 1 + } + } catch { + case iobex: IndexOutOfBoundsException => + throw new KeyFieldOutOfBoundsException(keyPositions(i)) + } + 0 + } + + override def putNormalizedKey( + record: Row, + target: MemorySegment, + offset: Int, + numBytes: Int) + : Unit = { + var bytesLeft = numBytes + var currentOffset = offset + + var i = 0 + while (i < numLeadingNormalizableKeys && bytesLeft > 0) { + var len = normalizedKeyLengths(i) + len = if (bytesLeft >= len) len else bytesLeft + + val comparator = comparators(i) + val element = record.productElement(keyPositions(i)) // element can be null + // write key + comparator.putNormalizedKey(element, target, currentOffset, len) + + bytesLeft -= len + currentOffset += len + i += 1 + } + } + + override def getFlatComparator(flatComparators: util.List[TypeComparator[_]]): Unit = + comparators.foreach { c => + c.getFlatComparators.foreach { fc => + flatComparators.add(fc) + } + } + + override def extractKeys(record: Any, target: Array[AnyRef], index: Int): Int = { + val len = comparators.length + var localIndex = index + var i = 0 + while (i < len) { + val element = record.asInstanceOf[Row].productElement(keyPositions(i)) // element can be null + localIndex += comparators(i).extractKeys(element, target, localIndex) + i += 1 + } + localIndex - index + } +} + +object RowComparator { + private def makeNullAware( + comparators: Array[TypeComparator[Any]], + orders: Array[Boolean]) + : Array[NullAwareComparator[Any]] = + comparators + .zip(orders) + .map { case (comp, order) => + new NullAwareComparator[Any]( + comp, + order) + } + + /** + * @return creates auxiliary fields for normalized key support + */ + private def createAuxiliaryFields( + keyPositions: Array[Int], + comparators: 
Array[NullAwareComparator[Any]]) + : (Array[Int], Int, Int, Boolean) = { + + val normalizedKeyLengths = new Array[Int](keyPositions.length) + var numLeadingNormalizableKeys = 0 + var normalizableKeyPrefixLen = 0 + var inverted = false + + var i = 0 + while (i < keyPositions.length) { + val k = comparators(i) + // as long as the leading keys support normalized keys, we can build up the composite key + if (k.supportsNormalizedKey()) { + if (i == 0) { + // the first comparator decides whether we need to invert the key direction + inverted = k.invertNormalizedKey() + } + else if (k.invertNormalizedKey() != inverted) { + // if a successor does not agree on the inversion direction, it cannot be part of the + // normalized key + return (normalizedKeyLengths, + numLeadingNormalizableKeys, + normalizableKeyPrefixLen, + inverted) + } + numLeadingNormalizableKeys += 1 + val len = k.getNormalizeKeyLen + if (len < 0) { + throw new RuntimeException("Comparator " + k.getClass.getName + + " specifies an invalid length for the normalized key: " + len) + } + normalizedKeyLengths(i) = len + normalizableKeyPrefixLen += len + if (normalizableKeyPrefixLen < 0) { + // overflow, which means we are out of budget for normalized key space anyways + return (normalizedKeyLengths, + numLeadingNormalizableKeys, + Integer.MAX_VALUE, + inverted) + } + } + else { + return (normalizedKeyLengths, + numLeadingNormalizableKeys, + normalizableKeyPrefixLen, + inverted) + } + i += 1 + } + (normalizedKeyLengths, + numLeadingNormalizableKeys, + normalizableKeyPrefixLen, + inverted) + } +} + http://git-wip-us.apache.org/repos/asf/flink/blob/440137cc/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeinfo/RowSerializer.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeinfo/RowSerializer.scala 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeinfo/RowSerializer.scala index 5e9613d..4f2d535 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeinfo/RowSerializer.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeinfo/RowSerializer.scala @@ -17,9 +17,10 @@ */ package org.apache.flink.api.table.typeinfo -import org.apache.flink.api.table.Row import org.apache.flink.api.common.typeutils.TypeSerializer -import org.apache.flink.core.memory.{DataOutputView, DataInputView} +import org.apache.flink.api.table.Row +import org.apache.flink.api.table.typeinfo.NullMaskUtils.{writeNullMask, readIntoNullMask, readIntoAndCopyNullMask} +import org.apache.flink.core.memory.{DataInputView, DataOutputView} /** * Serializer for [[Row]]. @@ -27,11 +28,16 @@ import org.apache.flink.core.memory.{DataOutputView, DataInputView} class RowSerializer(val fieldSerializers: Array[TypeSerializer[Any]]) extends TypeSerializer[Row] { + private val nullMask = new Array[Boolean](fieldSerializers.length) + override def isImmutableType: Boolean = false override def getLength: Int = -1 - override def duplicate = this + override def duplicate: RowSerializer = { + val duplicateFieldSerializers = fieldSerializers.map(_.duplicate()) + new RowSerializer(duplicateFieldSerializers) + } override def createInstance: Row = { new Row(fieldSerializers.length) @@ -40,15 +46,33 @@ class RowSerializer(val fieldSerializers: Array[TypeSerializer[Any]]) override def copy(from: Row, reuse: Row): Row = { val len = fieldSerializers.length - if (from.productArity != len) { - throw new RuntimeException("Row arity of reuse and from do not match.") + // cannot reuse, do a non-reuse copy + if (reuse == null) { + return copy(from) } + + if (from.productArity != len || reuse.productArity != len) { + throw new RuntimeException("Row arity of reuse or from is incompatible with this " + + "RowSerializer.") + } + var i = 0 while (i < 
len) { - val reuseField = reuse.productElement(i) - val fromField = from.productElement(i).asInstanceOf[AnyRef] - val copy = fieldSerializers(i).copy(fromField, reuseField) - reuse.setField(i, copy) + val fromField = from.productElement(i) + if (fromField != null) { + val reuseField = reuse.productElement(i) + if (reuseField != null) { + val copy = fieldSerializers(i).copy(fromField, reuseField) + reuse.setField(i, copy) + } + else { + val copy = fieldSerializers(i).copy(fromField) + reuse.setField(i, copy) + } + } + else { + reuse.setField(i, null) + } i += 1 } reuse @@ -58,14 +82,19 @@ class RowSerializer(val fieldSerializers: Array[TypeSerializer[Any]]) val len = fieldSerializers.length if (from.productArity != len) { - throw new RuntimeException("Row arity of reuse and from do not match.") + throw new RuntimeException("Row arity of from does not match serializers.") } val result = new Row(len) var i = 0 while (i < len) { val fromField = from.productElement(i).asInstanceOf[AnyRef] - val copy = fieldSerializers(i).copy(fromField) - result.setField(i, copy) + if (fromField != null) { + val copy = fieldSerializers(i).copy(fromField) + result.setField(i, copy) + } + else { + result.setField(i, null) + } i += 1 } result @@ -73,10 +102,22 @@ class RowSerializer(val fieldSerializers: Array[TypeSerializer[Any]]) override def serialize(value: Row, target: DataOutputView) { val len = fieldSerializers.length + + if (value.productArity != len) { + throw new RuntimeException("Row arity of value does not match serializers.") + } + + // write a null mask + writeNullMask(len, value, target) + + // serialize non-null fields var i = 0 while (i < len) { - val serializer = fieldSerializers(i) - serializer.serialize(value.productElement(i), target) + val o = value.productElement(i).asInstanceOf[AnyRef] + if (o != null) { + val serializer = fieldSerializers(i) + serializer.serialize(value.productElement(i), target) + } i += 1 } } @@ -85,13 +126,27 @@ class RowSerializer(val 
fieldSerializers: Array[TypeSerializer[Any]]) val len = fieldSerializers.length if (reuse.productArity != len) { - throw new RuntimeException("Row arity of reuse and fields do not match.") + throw new RuntimeException("Row arity of reuse does not match serializers.") } + // read null mask + readIntoNullMask(len, source, nullMask) + + // read non-null fields var i = 0 while (i < len) { - val field = reuse.productElement(i).asInstanceOf[AnyRef] - reuse.setField(i, fieldSerializers(i).deserialize(field, source)) + if (nullMask(i)) { + reuse.setField(i, null) + } + else { + val reuseField = reuse.productElement(i).asInstanceOf[AnyRef] + if (reuseField != null) { + reuse.setField(i, fieldSerializers(i).deserialize(reuseField, source)) + } + else { + reuse.setField(i, fieldSerializers(i).deserialize(source)) + } + } i += 1 } reuse @@ -101,9 +156,19 @@ class RowSerializer(val fieldSerializers: Array[TypeSerializer[Any]]) val len = fieldSerializers.length val result = new Row(len) + + // read null mask + readIntoNullMask(len, source, nullMask) + + // read non-null fields var i = 0 while (i < len) { - result.setField(i, fieldSerializers(i).deserialize(source)) + if (nullMask(i)) { + result.setField(i, null) + } + else { + result.setField(i, fieldSerializers(i).deserialize(source)) + } i += 1 } result @@ -111,14 +176,21 @@ class RowSerializer(val fieldSerializers: Array[TypeSerializer[Any]]) override def copy(source: DataInputView, target: DataOutputView): Unit = { val len = fieldSerializers.length + + // copy null mask + readIntoAndCopyNullMask(len, source, target, nullMask) + + // read non-null fields var i = 0 while (i < len) { - fieldSerializers(i).copy(source, target) + if (!nullMask(i)) { + fieldSerializers(i).copy(source, target) + } i += 1 } } - override def equals(any: scala.Any): Boolean = { + override def equals(any: Any): Boolean = { any match { case otherRS: RowSerializer => otherRS.canEqual(this) && @@ -127,7 +199,7 @@ class RowSerializer(val fieldSerializers: 
Array[TypeSerializer[Any]]) } } - override def canEqual(obj: scala.Any): Boolean = { + override def canEqual(obj: AnyRef): Boolean = { obj.isInstanceOf[RowSerializer] } http://git-wip-us.apache.org/repos/asf/flink/blob/440137cc/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeinfo/RowTypeInfo.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeinfo/RowTypeInfo.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeinfo/RowTypeInfo.scala index db3c881..39fc1d8 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeinfo/RowTypeInfo.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeinfo/RowTypeInfo.scala @@ -18,11 +18,14 @@ package org.apache.flink.api.table.typeinfo import org.apache.flink.api.common.ExecutionConfig -import org.apache.flink.api.table.Row import org.apache.flink.api.common.typeinfo.TypeInformation -import org.apache.flink.api.common.typeutils.TypeSerializer +import org.apache.flink.api.common.typeutils.CompositeType.TypeComparatorBuilder +import org.apache.flink.api.common.typeutils.{TypeComparator, TypeSerializer} +import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo +import org.apache.flink.api.table.Row import org.apache.flink.api.table.expressions.Expression -import org.apache.flink.api.scala.typeutils.{CaseClassTypeInfo} + +import scala.collection.mutable.ArrayBuffer /** * TypeInformation for [[Row]]. @@ -32,6 +35,11 @@ class RowTypeInfo( fieldNames: Seq[String]) extends CaseClassTypeInfo[Row](classOf[Row], Array(), fieldTypes, fieldNames) { + /** + * Temporary variable for directly passing orders to comparators. 
+ */ + var comparatorOrders: Option[Array[Boolean]] = None + def this(fields: Seq[Expression]) = this(fields.map(_.typeInfo), fields.map(_.name)) if (fieldNames.toSet.size != fieldNames.size) { @@ -47,5 +55,52 @@ class RowTypeInfo( new RowSerializer(fieldSerializers) } + + override def createComparator( + logicalKeyFields: Array[Int], + orders: Array[Boolean], + logicalFieldOffset: Int, + config: ExecutionConfig) + : TypeComparator[Row] = { + // store the order information for the builder + comparatorOrders = Some(orders) + val comparator = super.createComparator(logicalKeyFields, orders, logicalFieldOffset, config) + comparatorOrders = None + comparator + } + + override def createTypeComparatorBuilder(): TypeComparatorBuilder[Row] = { + new RowTypeComparatorBuilder(comparatorOrders.getOrElse( + throw new IllegalStateException("Cannot create comparator builder without orders."))) + } + + private class RowTypeComparatorBuilder( + comparatorOrders: Array[Boolean]) + extends TypeComparatorBuilder[Row] { + + val fieldComparators: ArrayBuffer[TypeComparator[_]] = new ArrayBuffer[TypeComparator[_]]() + val logicalKeyFields: ArrayBuffer[Int] = new ArrayBuffer[Int]() + + override def initializeTypeComparatorBuilder(size: Int): Unit = { + fieldComparators.sizeHint(size) + logicalKeyFields.sizeHint(size) + } + + override def addComparatorField(fieldId: Int, comparator: TypeComparator[_]): Unit = { + fieldComparators += comparator + logicalKeyFields += fieldId + } + + override def createTypeComparator(config: ExecutionConfig): TypeComparator[Row] = { + val maxIndex = logicalKeyFields.max + + new RowComparator( + logicalKeyFields.toArray, + fieldComparators.toArray.asInstanceOf[Array[TypeComparator[Any]]], + types.take(maxIndex + 1).map(_.createSerializer(config).asInstanceOf[TypeSerializer[Any]]), + comparatorOrders + ) + } + } } 
http://git-wip-us.apache.org/repos/asf/flink/blob/440137cc/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/GroupedAggreagationsITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/GroupedAggreagationsITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/GroupedAggreagationsITCase.scala deleted file mode 100644 index fb76507..0000000 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/GroupedAggreagationsITCase.scala +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.flink.api.scala.table.test - -import org.apache.flink.api.table.{Row, ExpressionException} -import org.apache.flink.api.scala._ -import org.apache.flink.api.scala.table._ -import org.apache.flink.api.scala.util.CollectionDataSets -import org.apache.flink.core.fs.FileSystem.WriteMode -import org.apache.flink.test.util.{TestBaseUtils, MultipleProgramsTestBase} -import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode -import org.junit._ -import org.junit.runner.RunWith -import org.junit.runners.Parameterized - -import scala.collection.JavaConverters._ - -@RunWith(classOf[Parameterized]) -class GroupedAggreagationsITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mode) { - - @Test(expected = classOf[ExpressionException]) - def testGroupingOnNonExistentField(): Unit = { - - val env = ExecutionEnvironment.getExecutionEnvironment - val ds = CollectionDataSets.get3TupleDataSet(env).as('a, 'b, 'c) - .groupBy('_foo) - .select('a.avg).toDataSet[Row] - val expected = "" - val results = ds.collect() - TestBaseUtils.compareResultAsText(results.asJava, expected) - } - - @Test - def testGroupedAggregate(): Unit = { - - // the grouping key needs to be forwarded to the intermediate DataSet, even - // if we don't want the key in the output - - val env = ExecutionEnvironment.getExecutionEnvironment - val ds = CollectionDataSets.get3TupleDataSet(env).as('a, 'b, 'c) - .groupBy('b) - .select('b, 'a.sum).toDataSet[Row] - val expected = "1,1\n" + "2,5\n" + "3,15\n" + "4,34\n" + "5,65\n" + "6,111\n" - val results = ds.collect() - TestBaseUtils.compareResultAsText(results.asJava, expected) - } - - @Test - def testGroupingKeyForwardIfNotUsed(): Unit = { - - // the grouping key needs to be forwarded to the intermediate DataSet, even - // if we don't want the key in the output - - val env = ExecutionEnvironment.getExecutionEnvironment - val ds = CollectionDataSets.get3TupleDataSet(env).as('a, 'b, 'c) - .groupBy('b) - 
.select('a.sum).toDataSet[Row] - val expected = "1\n" + "5\n" + "15\n" + "34\n" + "65\n" + "111\n" - val results = ds.collect() - TestBaseUtils.compareResultAsText(results.asJava, expected) - } - - @Test - def testSQLStyleAggregations(): Unit = { - - // the grouping key needs to be forwarded to the intermediate DataSet, even - // if we don't want the key in the output - - val env = ExecutionEnvironment.getExecutionEnvironment - val ds = CollectionDataSets.get3TupleDataSet(env).as('a, 'b, 'c) - .select( - """Sum( a) as a1, a.sum as a2, - |Min (a) as b1, a.min as b2, - |Max (a ) as c1, a.max as c2, - |Avg ( a ) as d1, a.avg as d2, - |Count(a) as e1, a.count as e2 - """.stripMargin).toDataSet[Row] - val expected = "231,231,1,1,21,21,11,11,21,21" - val results = ds.collect() - TestBaseUtils.compareResultAsText(results.asJava, expected) - } - - @Test - def testGroupNoAggregation(): Unit = { - - val env = ExecutionEnvironment.getExecutionEnvironment - val ds = CollectionDataSets.get3TupleDataSet(env) - .as('a, 'b, 'c) - .groupBy('b) - .select('a.sum as 'd, 'b) - .groupBy('b, 'd) - .select('b) - .toDataSet[Row] - - val expected = "1\n" + "2\n" + "3\n" + "4\n" + "5\n" + "6\n" - val results = ds.collect() - TestBaseUtils.compareResultAsText(results.asJava, expected) - } -} http://git-wip-us.apache.org/repos/asf/flink/blob/440137cc/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/GroupedAggregationsITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/GroupedAggregationsITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/GroupedAggregationsITCase.scala new file mode 100644 index 0000000..bbcf8a9 --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/GroupedAggregationsITCase.scala @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache 
Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.api.scala.table.test + +import org.apache.flink.api.scala._ +import org.apache.flink.api.scala.table._ +import org.apache.flink.api.scala.util.CollectionDataSets +import org.apache.flink.api.table.{ExpressionException, Row} +import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode +import org.apache.flink.test.util.{MultipleProgramsTestBase, TestBaseUtils} +import org.junit._ +import org.junit.runner.RunWith +import org.junit.runners.Parameterized + +import scala.collection.JavaConverters._ + +@RunWith(classOf[Parameterized]) +class GroupedAggregationsITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mode) { + + @Test(expected = classOf[ExpressionException]) + def testGroupingOnNonExistentField(): Unit = { + + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.get3TupleDataSet(env).as('a, 'b, 'c) + .groupBy('_foo) + .select('a.avg).toDataSet[Row] + val expected = "" + val results = ds.collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testGroupedAggregate(): Unit = { + + // the grouping key needs to be forwarded to the intermediate DataSet, even + // if we 
don't want the key in the output + + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.get3TupleDataSet(env).as('a, 'b, 'c) + .groupBy('b) + .select('b, 'a.sum).toDataSet[Row] + val expected = "1,1\n" + "2,5\n" + "3,15\n" + "4,34\n" + "5,65\n" + "6,111\n" + val results = ds.collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testGroupingKeyForwardIfNotUsed(): Unit = { + + // the grouping key needs to be forwarded to the intermediate DataSet, even + // if we don't want the key in the output + + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.get3TupleDataSet(env).as('a, 'b, 'c) + .groupBy('b) + .select('a.sum).toDataSet[Row] + val expected = "1\n" + "5\n" + "15\n" + "34\n" + "65\n" + "111\n" + val results = ds.collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testSQLStyleAggregations(): Unit = { + + // the grouping key needs to be forwarded to the intermediate DataSet, even + // if we don't want the key in the output + + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.get3TupleDataSet(env).as('a, 'b, 'c) + .select( + """Sum( a) as a1, a.sum as a2, + |Min (a) as b1, a.min as b2, + |Max (a ) as c1, a.max as c2, + |Avg ( a ) as d1, a.avg as d2, + |Count(a) as e1, a.count as e2 + """.stripMargin).toDataSet[Row] + val expected = "231,231,1,1,21,21,11,11,21,21" + val results = ds.collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testGroupNoAggregation(): Unit = { + + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.get3TupleDataSet(env) + .as('a, 'b, 'c) + .groupBy('b) + .select('a.sum as 'd, 'b) + .groupBy('b, 'd) + .select('b) + .toDataSet[Row] + + val expected = "1\n" + "2\n" + "3\n" + "4\n" + "5\n" + "6\n" + val results = ds.collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + 
@Test + def testGroupedAggregateWithLongKeys(): Unit = { + // This uses very long keys to force serialized comparison. + // With short keys, the normalized key is sufficient. + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = env.fromElements( + ("hhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhaa", 1, 2), + ("hhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhaa", 1, 2), + ("hhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhaa", 1, 2), + ("hhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhaa", 1, 2), + ("hhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhaa", 1, 2), + ("hhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhab", 1, 2), + ("hhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhab", 1, 2), + ("hhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhab", 1, 2), + ("hhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhab", 1, 2)) + .rebalance().setParallelism(2).as('a, 'b, 'c) + .groupBy('a, 'b) + .select('c.sum) + + val expected = "10\n" + "8\n" + val results = ds.collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/440137cc/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/typeinfo/RowComparatorTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/typeinfo/RowComparatorTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/typeinfo/RowComparatorTest.scala new file mode 100644 index 0000000..a58d0b7 --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/typeinfo/RowComparatorTest.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.api.table.typeinfo + +import org.apache.flink.api.common.ExecutionConfig +import org.apache.flink.api.common.typeinfo.BasicTypeInfo +import org.apache.flink.api.common.typeutils.{ComparatorTestBase, TypeComparator, TypeSerializer} +import org.apache.flink.api.java.tuple +import org.apache.flink.api.java.typeutils.{TupleTypeInfo, TypeExtractor} +import org.apache.flink.api.table.Row +import org.apache.flink.api.table.typeinfo.RowComparatorTest.MyPojo +import org.junit.Assert._ + +class RowComparatorTest extends ComparatorTestBase[Row] { + + val typeInfo = new RowTypeInfo( + Array( + BasicTypeInfo.INT_TYPE_INFO, + BasicTypeInfo.DOUBLE_TYPE_INFO, + BasicTypeInfo.STRING_TYPE_INFO, + new TupleTypeInfo[tuple.Tuple2[Int, Boolean]]( + BasicTypeInfo.INT_TYPE_INFO, + BasicTypeInfo.BOOLEAN_TYPE_INFO, + BasicTypeInfo.SHORT_TYPE_INFO), + TypeExtractor.createTypeInfo(classOf[MyPojo])), + Array("f0", "f1", "f2", "f3", "f4")) + + val testPojo1 = new MyPojo() + // TODO we cannot test null here as PojoComparator has no support for null keys + testPojo1.name = "" + val testPojo2 = new MyPojo() + testPojo2.name = "Test1" + val testPojo3 = new MyPojo() + testPojo3.name = "Test2" + + val data: Array[Row] = Array( + createRow(null, null, null, null, null), + 
createRow(0, null, null, null, null), + createRow(0, 0.0, null, null, null), + createRow(0, 0.0, "a", null, null), + createRow(1, 0.0, "a", null, null), + createRow(1, 1.0, "a", null, null), + createRow(1, 1.0, "b", null, null), + createRow(1, 1.0, "b", new tuple.Tuple3[Int, Boolean, Short](1, false, 2), null), + createRow(1, 1.0, "b", new tuple.Tuple3[Int, Boolean, Short](2, false, 2), null), + createRow(1, 1.0, "b", new tuple.Tuple3[Int, Boolean, Short](2, true, 2), null), + createRow(1, 1.0, "b", new tuple.Tuple3[Int, Boolean, Short](2, true, 3), null), + createRow(1, 1.0, "b", new tuple.Tuple3[Int, Boolean, Short](2, true, 3), testPojo1), + createRow(1, 1.0, "b", new tuple.Tuple3[Int, Boolean, Short](2, true, 3), testPojo2), + createRow(1, 1.0, "b", new tuple.Tuple3[Int, Boolean, Short](2, true, 3), testPojo3) + ) + + override protected def deepEquals(message: String, should: Row, is: Row): Unit = { + val arity = should.productArity + assertEquals(message, arity, is.productArity) + var index = 0 + while (index < arity) { + val copiedValue: Any = should.productElement(index) + val element: Any = is.productElement(index) + assertEquals(message, element, copiedValue) + index += 1 + } + } + + override protected def createComparator(ascending: Boolean): TypeComparator[Row] = { + typeInfo.createComparator( + Array(0, 1, 2, 3, 4, 5, 6), + Array(ascending, ascending, ascending, ascending, ascending, ascending, ascending), + 0, + new ExecutionConfig()) + } + + override protected def createSerializer(): TypeSerializer[Row] = { + typeInfo.createSerializer(new ExecutionConfig()) + } + + override protected def getSortedTestData: Array[Row] = { + data + } + + override protected def supportsNullKeys: Boolean = true + + def createRow(f0: Any, f1: Any, f2: Any, f3: Any, f4: Any): Row = { + val r: Row = new Row(5) + r.setField(0, f0) + r.setField(1, f1) + r.setField(2, f2) + r.setField(3, f3) + r.setField(4, f4) + r + } +} + +object RowComparatorTest { + class MyPojo() extends 
Serializable with Comparable[MyPojo] { + // we cannot use null because the PojoComparator does not support null properly + var name: String = "" + + override def compareTo(o: MyPojo): Int = { + if (name == null && o.name == null) { + 0 + } + else if (name == null) { + -1 + } + else if (o.name == null) { + 1 + } + else { + name.compareTo(o.name) + } + } + + override def equals(other: Any): Boolean = other match { + case that: MyPojo => compareTo(that) == 0 + case _ => false + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/440137cc/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/typeinfo/RowSerializerTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/typeinfo/RowSerializerTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/typeinfo/RowSerializerTest.scala new file mode 100644 index 0000000..60a02ae --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/typeinfo/RowSerializerTest.scala @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.api.table.typeinfo + +import org.apache.flink.api.common.ExecutionConfig +import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} +import org.apache.flink.api.common.typeutils.{SerializerTestInstance, TypeSerializer} +import org.apache.flink.api.java.tuple +import org.apache.flink.api.java.typeutils.{TypeExtractor, TupleTypeInfo} +import org.apache.flink.api.table.Row +import org.apache.flink.api.table.typeinfo.RowSerializerTest.MyPojo +import org.junit.Assert._ +import org.junit.Test + +class RowSerializerTest { + + class RowSerializerTestInstance( + serializer: TypeSerializer[Row], + testData: Array[Row]) + extends SerializerTestInstance[Row](serializer, classOf[Row], -1, testData: _*) { + + override protected def deepEquals(message: String, should: Row, is: Row): Unit = { + val arity = should.productArity + assertEquals(message, arity, is.productArity) + var index = 0 + while (index < arity) { + val copiedValue: Any = should.productElement(index) + val element: Any = is.productElement(index) + assertEquals(message, element, copiedValue) + index += 1 + } + } + } + + @Test + def testRowSerializer(): Unit = { + val rowInfo: TypeInformation[Row] = new RowTypeInfo( + Seq(BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO), Seq("id", "name")) + + val row1 = new Row(2) + row1.setField(0, 1) + row1.setField(1, "a") + + val row2 = new Row(2) + row2.setField(0, 2) + row2.setField(1, null) + + val testData: Array[Row] = Array(row1, row2) + + val rowSerializer: TypeSerializer[Row] = rowInfo.createSerializer(new ExecutionConfig) + + val testInstance = new RowSerializerTestInstance(rowSerializer, testData) + + testInstance.testAll() + } + + @Test + def testLargeRowSerializer(): Unit = { + val rowInfo: TypeInformation[Row] = new RowTypeInfo(Seq( + BasicTypeInfo.INT_TYPE_INFO, + BasicTypeInfo.INT_TYPE_INFO, + BasicTypeInfo.INT_TYPE_INFO, + BasicTypeInfo.INT_TYPE_INFO, + BasicTypeInfo.INT_TYPE_INFO, + 
BasicTypeInfo.INT_TYPE_INFO, + BasicTypeInfo.INT_TYPE_INFO, + BasicTypeInfo.INT_TYPE_INFO, + BasicTypeInfo.INT_TYPE_INFO, + BasicTypeInfo.INT_TYPE_INFO, + BasicTypeInfo.INT_TYPE_INFO, + BasicTypeInfo.INT_TYPE_INFO, + BasicTypeInfo.STRING_TYPE_INFO), Seq( + "id0", + "id1", + "id2", + "id3", + "id4", + "id5", + "id6", + "id7", + "id8", + "id9", + "id10", + "id11", + "name")) + + val row = new Row(13) + row.setField(0, 2) + row.setField(1, null) + row.setField(3, null) + row.setField(4, null) + row.setField(5, null) + row.setField(6, null) + row.setField(7, null) + row.setField(8, null) + row.setField(9, null) + row.setField(10, null) + row.setField(11, null) + row.setField(12, "Test") + + val testData: Array[Row] = Array(row) + + val rowSerializer: TypeSerializer[Row] = rowInfo.createSerializer(new ExecutionConfig) + + val testInstance = new RowSerializerTestInstance(rowSerializer, testData) + + testInstance.testAll() + } + + @Test + def testRowSerializerWithComplexTypes(): Unit = { + val rowInfo = new RowTypeInfo( + Array( + BasicTypeInfo.INT_TYPE_INFO, + BasicTypeInfo.DOUBLE_TYPE_INFO, + BasicTypeInfo.STRING_TYPE_INFO, + new TupleTypeInfo[tuple.Tuple2[Int, Boolean]]( + BasicTypeInfo.INT_TYPE_INFO, + BasicTypeInfo.BOOLEAN_TYPE_INFO, + BasicTypeInfo.SHORT_TYPE_INFO), + TypeExtractor.createTypeInfo(classOf[MyPojo])), + Array("f0", "f1", "f2", "f3", "f4")) + + val testPojo1 = new MyPojo() + testPojo1.name = null + val testPojo2 = new MyPojo() + testPojo2.name = "Test1" + val testPojo3 = new MyPojo() + testPojo3.name = "Test2" + + val testData: Array[Row] = Array( + createRow(null, null, null, null, null), + createRow(0, null, null, null, null), + createRow(0, 0.0, null, null, null), + createRow(0, 0.0, "a", null, null), + createRow(1, 0.0, "a", null, null), + createRow(1, 1.0, "a", null, null), + createRow(1, 1.0, "b", null, null), + createRow(1, 1.0, "b", new tuple.Tuple3[Int, Boolean, Short](1, false, 2), null), + createRow(1, 1.0, "b", new tuple.Tuple3[Int, Boolean, 
Short](2, false, 2), null), + createRow(1, 1.0, "b", new tuple.Tuple3[Int, Boolean, Short](2, true, 2), null), + createRow(1, 1.0, "b", new tuple.Tuple3[Int, Boolean, Short](2, true, 3), null), + createRow(1, 1.0, "b", new tuple.Tuple3[Int, Boolean, Short](2, true, 3), testPojo1), + createRow(1, 1.0, "b", new tuple.Tuple3[Int, Boolean, Short](2, true, 3), testPojo2), + createRow(1, 1.0, "b", new tuple.Tuple3[Int, Boolean, Short](2, true, 3), testPojo3) + ) + + val rowSerializer: TypeSerializer[Row] = rowInfo.createSerializer(new ExecutionConfig) + + val testInstance = new RowSerializerTestInstance(rowSerializer, testData) + + testInstance.testAll() + } + + // ---------------------------------------------------------------------------------------------- + + def createRow(f0: Any, f1: Any, f2: Any, f3: Any, f4: Any): Row = { + val r: Row = new Row(5) + r.setField(0, f0) + r.setField(1, f1) + r.setField(2, f2) + r.setField(3, f3) + r.setField(4, f4) + r + } +} + +object RowSerializerTest { + class MyPojo() extends Serializable with Comparable[MyPojo] { + var name: String = null + + override def compareTo(o: MyPojo): Int = { + if (name == null && o.name == null) { + 0 + } + else if (name == null) { + -1 + } + else if (o.name == null) { + 1 + } + else { + name.compareTo(o.name) + } + } + + override def equals(other: Any): Boolean = other match { + case that: MyPojo => compareTo(that) == 0 + case _ => false + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/440137cc/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/CaseClassTypeInfo.scala ---------------------------------------------------------------------- diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/CaseClassTypeInfo.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/CaseClassTypeInfo.scala index 37c7431..5db7a91 100644 --- a/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/CaseClassTypeInfo.scala +++ 
b/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/CaseClassTypeInfo.scala @@ -205,7 +205,10 @@ abstract class CaseClassTypeInfo[T <: Product]( val fieldComparators: ArrayBuffer[TypeComparator[_]] = new ArrayBuffer[TypeComparator[_]]() val logicalKeyFields: ArrayBuffer[Int] = new ArrayBuffer[Int]() - override def initializeTypeComparatorBuilder(size: Int): Unit = {} + override def initializeTypeComparatorBuilder(size: Int): Unit = { + fieldComparators.sizeHint(size) + logicalKeyFields.sizeHint(size) + } override def addComparatorField(fieldId: Int, comparator: TypeComparator[_]): Unit = { fieldComparators += comparator