Repository: spark
Updated Branches:
  refs/heads/master 712f5b7a9 -> 03377d252


[SPARK-9358][SQL] Code generation for UnsafeRow joiner.

This patch creates a code-generated UnsafeRow concatenator that can be used to
concatenate (join) two UnsafeRows into a single UnsafeRow.

Since this low-level code is inherently hard to test, the test suites rely
heavily on randomized testing to build confidence in its correctness.

Author: Reynold Xin <r...@databricks.com>

Closes #7821 from rxin/rowconcat and squashes the following commits:

8717f35 [Reynold Xin] Rebase and code review.
72c5d8e [Reynold Xin] Fixed a bug.
a84ed2e [Reynold Xin] Fixed offset.
40c3fb2 [Reynold Xin] Reset random data generator.
f0913aa [Reynold Xin] Test fixes.
6687b6f [Reynold Xin] Updated documentation.
00354b9 [Reynold Xin] Support concat data as well.
e9a4347 [Reynold Xin] Updated.
6269f96 [Reynold Xin] Fixed a bug .
0f89716 [Reynold Xin] [SPARK-9358][SQL][WIP] Code generation for UnsafeRow concat.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/03377d25
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/03377d25
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/03377d25

Branch: refs/heads/master
Commit: 03377d2522776267a07b7d6ae9bddf79a4e0f516
Parents: 712f5b7
Author: Reynold Xin <r...@databricks.com>
Authored: Fri Jul 31 21:09:00 2015 -0700
Committer: Reynold Xin <r...@databricks.com>
Committed: Fri Jul 31 21:09:00 2015 -0700

----------------------------------------------------------------------
 .../sql/catalyst/expressions/UnsafeRow.java     |  19 ++
 .../expressions/codegen/CodeGenerator.scala     |   2 +
 .../codegen/GenerateUnsafeProjection.scala      |   6 +-
 .../codegen/GenerateUnsafeRowJoiner.scala       | 241 +++++++++++++++++++
 .../apache/spark/sql/RandomDataGenerator.scala  |  15 +-
 .../GenerateUnsafeRowJoinerBitsetSuite.scala    | 147 +++++++++++
 .../codegen/GenerateUnsafeRowJoinerSuite.scala  | 114 +++++++++
 .../UnsafeFixedWidthAggregationMap.java         |   7 +-
 .../spark/sql/execution/TungstenSortSuite.scala |   3 +
 9 files changed, 544 insertions(+), 10 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/03377d25/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index e7088ed..24dc80b 100644
--- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -85,6 +85,14 @@ public final class UnsafeRow extends MutableRow {
         })));
   }
 
+  /**
+   * Returns true if a value of the given type lives entirely in its 8-byte slot of the
+   * fixed-length region of an UnsafeRow, i.e. it never references the variable-length region.
+   *
+   * @param dt the field's data type
+   * @return true for decimals whose precision is at most Decimal.MAX_LONG_DIGITS and for every
+   *         type in settableFieldTypes; false otherwise
+   */
+  public static boolean isFixedLength(DataType dt) {
+    if (dt instanceof DecimalType) {
+      // Inclusive bound: a decimal with exactly MAX_LONG_DIGITS digits still fits in a long.
+      return ((DecimalType) dt).precision() <= Decimal.MAX_LONG_DIGITS();
+    } else {
+      return settableFieldTypes.contains(dt);
+    }
+  }
+
   
//////////////////////////////////////////////////////////////////////////////
   // Private fields and methods
   
//////////////////////////////////////////////////////////////////////////////
@@ -144,6 +152,17 @@ public final class UnsafeRow extends MutableRow {
     this.sizeInBytes = sizeInBytes;
   }
 
+  /**
+   * Re-targets this row at an on-heap byte array whose row data begins at the array's first
+   * element (PlatformDependent.BYTE_ARRAY_OFFSET).
+   *
+   * @param buf the byte array that holds the row's data
+   * @param numFields how many fields the row contains
+   * @param sizeInBytes how many bytes of {@code buf} belong to the row
+   */
+  public void pointTo(byte[] buf, int numFields, int sizeInBytes) {
+    final long baseOffset = PlatformDependent.BYTE_ARRAY_OFFSET;
+    this.pointTo(buf, baseOffset, numFields, sizeInBytes);
+  }
+
+
   private void assertIndexIsValid(int index) {
     assert index >= 0 : "index (" + index + ") should >= 0";
     assert index < numFields : "index (" + index + ") should < " + numFields;

http://git-wip-us.apache.org/repos/asf/spark/blob/03377d25/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index e50ec27..36f4e9c 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -27,6 +27,7 @@ import org.apache.spark.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.PlatformDependent
 import org.apache.spark.unsafe.types._
 
 
@@ -293,6 +294,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: 
AnyRef] extends Loggin
     val evaluator = new ClassBodyEvaluator()
     evaluator.setParentClassLoader(getClass.getClassLoader)
     evaluator.setDefaultImports(Array(
+      classOf[PlatformDependent].getName,
       classOf[InternalRow].getName,
       classOf[UnsafeRow].getName,
       classOf[UTF8String].getName,

http://git-wip-us.apache.org/repos/asf/spark/blob/03377d25/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 1d22398..6c99086 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -266,16 +266,16 @@ object GenerateUnsafeProjection extends 
CodeGenerator[Seq[Expression], UnsafePro
 
     val code = s"""
       public Object generate($exprType[] exprs) {
-        return new SpecificProjection(exprs);
+        return new SpecificUnsafeProjection(exprs);
       }
 
-      class SpecificProjection extends ${classOf[UnsafeProjection].getName} {
+      class SpecificUnsafeProjection extends 
${classOf[UnsafeProjection].getName} {
 
         private $exprType[] expressions;
 
         ${declareMutableStates(ctx)}
 
-        public SpecificProjection($exprType[] expressions) {
+        public SpecificUnsafeProjection($exprType[] expressions) {
           this.expressions = expressions;
           ${initMutableStates(ctx)}
         }

http://git-wip-us.apache.org/repos/asf/spark/blob/03377d25/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
new file mode 100644
index 0000000..645eb48
--- /dev/null
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
@@ -0,0 +1,241 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, Attribute}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.unsafe.PlatformDependent
+
+
+/**
+ * Concatenates two [[UnsafeRow]]s into one [[UnsafeRow]] whose fields are the fields of the
+ * first row followed by the fields of the second row.
+ */
+abstract class UnsafeRowJoiner {
+
+  /** Produces a single row holding `row1`'s fields followed by `row2`'s fields. */
+  def join(row1: UnsafeRow, row2: UnsafeRow): UnsafeRow
+}
+
+
+/**
+ * A code generator for concatenating two [[UnsafeRow]]s into a single [[UnsafeRow]].
+ *
+ * The high level algorithm is:
+ *
+ * 1. Concatenate the two bitsets together into a single one, taking padding into account.
+ * 2. Move fixed-length data.
+ * 3. Move variable-length data.
+ * 4. Update the offset position (i.e. the upper 32 bits in the fixed length part) for all
+ *    variable-length data.
+ *
+ * Every read from an input row is relative to that row's own base object and base offset, so
+ * the inputs need not start at BYTE_ARRAY_OFFSET. The generated joiner writes its result into
+ * an internal byte array that is reused, and therefore overwritten, by the next call to `join`.
+ */
+object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), UnsafeRowJoiner] {
+
+  /** Renders a 64-bit word as a binary string, most significant bit first (debugging aid). */
+  def dump(word: Long): String = {
+    Seq.tabulate(64) { i => if ((word >> i) % 2 == 0) "0" else "1" }.reverse.mkString
+  }
+
+  override protected def create(in: (StructType, StructType)): UnsafeRowJoiner = {
+    create(in._1, in._2)
+  }
+
+  override protected def canonicalize(in: (StructType, StructType)): (StructType, StructType) = in
+
+  override protected def bind(in: (StructType, StructType), inputSchema: Seq[Attribute])
+    : (StructType, StructType) = {
+    in
+  }
+
+  /**
+   * Generates and compiles a joiner specialized for rows of `schema1` and `schema2`.
+   *
+   * @param schema1 schema of the first input row; its fields come first in the output
+   * @param schema2 schema of the second input row; its fields follow those of `schema1`
+   * @return a joiner whose output row is backed by a buffer reused across calls to `join`
+   */
+  def create(schema1: StructType, schema2: StructType): UnsafeRowJoiner = {
+    // The output always lives in an on-heap byte array, so output addresses are relative to
+    // BYTE_ARRAY_OFFSET. Input addresses are always relative to offset1 / offset2.
+    val offset = PlatformDependent.BYTE_ARRAY_OFFSET
+    val getLong = "PlatformDependent.UNSAFE.getLong"
+    val putLong = "PlatformDependent.UNSAFE.putLong"
+
+    val bitset1Words = (schema1.size + 63) / 64
+    val bitset2Words = (schema2.size + 63) / 64
+    val outputBitsetWords = (schema1.size + schema2.size + 63) / 64
+    val bitset1Remainder = schema1.size % 64
+
+    // The number of BYTES saved when we concat two rows together. The only reduction comes
+    // from merging the bitset portion of the two rows, which saves at most one 8-byte word.
+    val sizeReduction = (bitset1Words + bitset2Words - outputBitsetWords) * 8
+
+    // --------------------- copy bitset from row 1 and row 2 --------------------------- //
+    // Each output word i is assembled from the input words that contribute bits to it. When
+    // row1's bitset ends on a word boundary (or row2 has no fields), whole words are copied.
+    // Otherwise row2's bits are shifted left by bitset1Remainder, so the top bits of each row2
+    // word spill over into the next output word.
+    val copyBitset = Seq.tabulate(outputBitsetWords) { i =>
+      val bits = if (bitset1Remainder > 0 && bitset2Words != 0) {
+        if (i < bitset1Words - 1) {
+          s"$getLong(obj1, offset1 + ${i * 8})"
+        } else if (i == bitset1Words - 1) {
+          // Last word of row1's bitset, topped up with the low bits of row2's first word.
+          s"$getLong(obj1, offset1 + ${i * 8}) | ($getLong(obj2, offset2) << $bitset1Remainder)"
+        } else if (i - bitset1Words < bitset2Words - 1) {
+          // High bits of row2 word j followed by the low bits of row2 word j + 1.
+          val j = i - bitset1Words
+          s"($getLong(obj2, offset2 + ${j * 8}) >>> ${64 - bitset1Remainder})" +
+            s" | ($getLong(obj2, offset2 + ${(j + 1) * 8}) << $bitset1Remainder)"
+        } else {
+          // Only the high bits of row2's last word are left.
+          s"$getLong(obj2, offset2 + ${(i - bitset1Words) * 8}) >>> ${64 - bitset1Remainder}"
+        }
+      } else {
+        // The two bitsets are word aligned: copy words verbatim.
+        if (i < bitset1Words) {
+          s"$getLong(obj1, offset1 + ${i * 8})"
+        } else {
+          s"$getLong(obj2, offset2 + ${(i - bitset1Words) * 8})"
+        }
+      }
+      s"$putLong(buf, ${offset + i * 8}, $bits);\n"
+    }.mkString
+
+    // --------------------- copy fixed length portion from row 1 ----------------------- //
+    var cursor = offset + outputBitsetWords * 8
+    val copyFixedLengthRow1 = s"""
+       |// Copy fixed length data for row1
+       |PlatformDependent.copyMemory(
+       |  obj1, offset1 + ${bitset1Words * 8},
+       |  buf, $cursor,
+       |  ${schema1.size * 8});
+     """.stripMargin
+    cursor += schema1.size * 8
+
+    // --------------------- copy fixed length portion from row 2 ----------------------- //
+    val copyFixedLengthRow2 = s"""
+       |// Copy fixed length data for row2
+       |PlatformDependent.copyMemory(
+       |  obj2, offset2 + ${bitset2Words * 8},
+       |  buf, $cursor,
+       |  ${schema2.size * 8});
+     """.stripMargin
+    cursor += schema2.size * 8
+
+    // --------------------- copy variable length portion from row 1 ----------------------- //
+    val copyVariableLengthRow1 = s"""
+       |// Copy variable length data for row1
+       |long numBytesBitsetAndFixedRow1 = ${(bitset1Words + schema1.size) * 8};
+       |long numBytesVariableRow1 = row1.getSizeInBytes() - numBytesBitsetAndFixedRow1;
+       |PlatformDependent.copyMemory(
+       |  obj1, offset1 + ${(bitset1Words + schema1.size) * 8},
+       |  buf, $cursor,
+       |  numBytesVariableRow1);
+     """.stripMargin
+
+    // --------------------- copy variable length portion from row 2 ----------------------- //
+    val copyVariableLengthRow2 = s"""
+       |// Copy variable length data for row2
+       |long numBytesBitsetAndFixedRow2 = ${(bitset2Words + schema2.size) * 8};
+       |long numBytesVariableRow2 = row2.getSizeInBytes() - numBytesBitsetAndFixedRow2;
+       |PlatformDependent.copyMemory(
+       |  obj2, offset2 + ${(bitset2Words + schema2.size) * 8},
+       |  buf, $cursor + numBytesVariableRow1,
+       |  numBytesVariableRow2);
+     """.stripMargin
+
+    // ------------- update fixed length data for variable length data type  --------------- //
+    val updateOffset = (schema1 ++ schema2).zipWithIndex.map { case (field, i) =>
+      // Skip fixed length data types, and only generate code for variable length data
+      if (UnsafeRow.isFixedLength(field.dataType)) {
+        ""
+      } else {
+        // Number of bytes to increase for the offset. Since UnsafeRow stores the offset in the
+        // upper 32 bits of the word, we shift the delta left by 32 and add it in place. Row1's
+        // data moves by the change in bitset size plus row2's fixed-length region; row2's data
+        // additionally moves past row1's variable-length bytes.
+        val shift =
+          if (i < schema1.size) {
+            s"${(outputBitsetWords - bitset1Words + schema2.size) * 8}L"
+          } else {
+            s"(${(outputBitsetWords - bitset2Words + schema1.size) * 8}L + numBytesVariableRow1)"
+          }
+        val fieldCursor = offset + outputBitsetWords * 8 + i * 8
+        s"""
+           |$putLong(buf, $fieldCursor,
+           |  $getLong(buf, $fieldCursor) + ($shift << 32));
+         """.stripMargin
+      }
+    }.mkString
+
+    // ------------------------ Finally, put everything together  --------------------------- //
+    val code = s"""
+       |public Object generate($exprType[] exprs) {
+       |  return new SpecificUnsafeRowJoiner();
+       |}
+       |
+       |class SpecificUnsafeRowJoiner extends ${classOf[UnsafeRowJoiner].getName} {
+       |  private byte[] buf = new byte[64];
+       |  private UnsafeRow out = new UnsafeRow();
+       |
+       |  public UnsafeRow join(UnsafeRow row1, UnsafeRow row2) {
+       |    // row1: ${schema1.size} fields, $bitset1Words words in bitset
+       |    // row2: ${schema2.size} fields, $bitset2Words words in bitset
+       |    // output: ${schema1.size + schema2.size} fields, $outputBitsetWords words in bitset
+       |    final int sizeInBytes = row1.getSizeInBytes() + row2.getSizeInBytes() - $sizeReduction;
+       |    if (sizeInBytes > buf.length) {
+       |      buf = new byte[sizeInBytes];
+       |    }
+       |
+       |    final Object obj1 = row1.getBaseObject();
+       |    final long offset1 = row1.getBaseOffset();
+       |    final Object obj2 = row2.getBaseObject();
+       |    final long offset2 = row2.getBaseOffset();
+       |
+       |    $copyBitset
+       |    $copyFixedLengthRow1
+       |    $copyFixedLengthRow2
+       |    $copyVariableLengthRow1
+       |    $copyVariableLengthRow2
+       |    $updateOffset
+       |
+       |    out.pointTo(buf, ${schema1.size + schema2.size}, sizeInBytes);
+       |
+       |    return out;
+       |  }
+       |}
+     """.stripMargin
+
+    logDebug(s"SpecificUnsafeRowJoiner($schema1, $schema2):\n${CodeFormatter.format(code)}")
+
+    val c = compile(code)
+    c.generate(Array.empty).asInstanceOf[UnsafeRowJoiner]
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/03377d25/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
index 75ae29d..81267dc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
@@ -66,6 +66,19 @@ object RandomDataGenerator {
   }
 
   /**
+   * Returns a randomly generated schema, based on the given accepted types.
+   *
+   * @param numFields the number of fields in this schema
+   * @param acceptedTypes types to draw from.
+   */
+  def randomSchema(numFields: Int, acceptedTypes: Seq[DataType]): StructType = {
+    // Each column is named col_<index>, is nullable, and draws its type uniformly at random.
+    val fields = (0 until numFields).map { idx =>
+      val pickedType = acceptedTypes(Random.nextInt(acceptedTypes.size))
+      StructField(s"col_$idx", pickedType, nullable = true)
+    }
+    StructType(fields)
+  }
+
+  /**
    * Returns a function which generates random values for the given 
[[DataType]], or `None` if no
    * random data generator is defined for that data type. The generated values 
will use an external
    * representation of the data type; for example, the random generator for 
[[DateType]] will return
@@ -94,7 +107,7 @@ object RandomDataGenerator {
       case DateType => Some(() => new java.sql.Date(rand.nextInt()))
       case TimestampType => Some(() => new java.sql.Timestamp(rand.nextLong()))
       case DecimalType.Fixed(precision, scale) => Some(
-        () => BigDecimal.apply(rand.nextLong, rand.nextInt, new 
MathContext(precision)))
+        () => BigDecimal.apply(rand.nextLong(), rand.nextInt(), new 
MathContext(precision)))
       case DoubleType => randomNumeric[Double](
         rand, r => longBitsToDouble(r.nextLong()), Seq(Double.MinValue, 
Double.MinPositiveValue,
           Double.MaxValue, Double.PositiveInfinity, Double.NegativeInfinity, 
Double.NaN, 0.0))

http://git-wip-us.apache.org/repos/asf/spark/blob/03377d25/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala
new file mode 100644
index 0000000..76d9d99
--- /dev/null
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import scala.util.Random
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.types._
+
+/**
+ * A test suite for the bitset portion of the row concatenation.
+ */
+class GenerateUnsafeRowJoinerBitsetSuite extends SparkFunSuite {
+
+  test("bitset concat: boundary size 0, 0") {
+    testBitsets(0, 0)
+  }
+
+  test("bitset concat: boundary size 0, 64") {
+    testBitsets(0, 64)
+  }
+
+  test("bitset concat: boundary size 64, 0") {
+    testBitsets(64, 0)
+  }
+
+  test("bitset concat: boundary size 64, 64") {
+    testBitsets(64, 64)
+  }
+
+  test("bitset concat: boundary size 0, 128") {
+    testBitsets(0, 128)
+  }
+
+  test("bitset concat: boundary size 128, 0") {
+    testBitsets(128, 0)
+  }
+
+  test("bitset concat: boundary size 128, 128") {
+    testBitsets(128, 128)
+  }
+
+  test("bitset concat: single word bitsets") {
+    testBitsets(10, 5)
+  }
+
+  test("bitset concat: first bitset larger than a word") {
+    testBitsets(67, 5)
+  }
+
+  test("bitset concat: second bitset larger than a word") {
+    testBitsets(6, 67)
+  }
+
+  test("bitset concat: no reduction in bitset size") {
+    testBitsets(33, 34)
+  }
+
+  test("bitset concat: two words") {
+    testBitsets(120, 95)
+  }
+
+  test("bitset concat: bitset 65, 128") {
+    testBitsets(65, 128)
+  }
+
+  test("bitset concat: rows that do not start at the beginning of their byte array") {
+    // The joiner must read each input relative to the row's base offset, not BYTE_ARRAY_OFFSET.
+    testBitsets(10, 5, padding = 16)
+    testBitsets(67, 5, padding = 16)
+    testBitsets(6, 67, padding = 16)
+    testBitsets(65, 128, padding = 16)
+  }
+
+  test("bitset concat: randomized tests") {
+    for (_ <- 1 until 20) {
+      val numFields1 = Random.nextInt(1000)
+      val numFields2 = Random.nextInt(1000)
+      testBitsetsOnce(numFields1, numFields2)
+    }
+  }
+
+  /**
+   * Creates an UnsafeRow with `numFields` fields and no variable-length data, backed by a fresh
+   * zero-filled byte array in which the row starts `padding` bytes past BYTE_ARRAY_OFFSET.
+   */
+  private def createUnsafeRow(numFields: Int, padding: Int = 0): UnsafeRow = {
+    val row = new UnsafeRow
+    val sizeInBytes = numFields * 8 + ((numFields + 63) / 64) * 8
+    val buf = new Array[Byte](padding + sizeInBytes)
+    row.pointTo(
+      buf, org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET + padding,
+      numFields, sizeInBytes)
+    row
+  }
+
+  /** Runs [[testBitsetsOnce]] five times, each with a different random set of null fields. */
+  private def testBitsets(numFields1: Int, numFields2: Int, padding: Int = 0): Unit = {
+    for (_ <- 0 until 5) {
+      testBitsetsOnce(numFields1, numFields2, padding)
+    }
+  }
+
+  /**
+   * Joins two integer-only rows with randomly chosen null fields and checks that the output's
+   * null bits and total size match the inputs.
+   */
+  private def testBitsetsOnce(numFields1: Int, numFields2: Int, padding: Int = 0): Unit = {
+    info(s"num fields: $numFields1 and $numFields2")
+    val schema1 = StructType(Seq.tabulate(numFields1) { i => StructField(s"a_$i", IntegerType) })
+    val schema2 = StructType(Seq.tabulate(numFields2) { i => StructField(s"b_$i", IntegerType) })
+
+    val row1 = createUnsafeRow(numFields1, padding)
+    val row2 = createUnsafeRow(numFields2, padding)
+
+    if (numFields1 > 0) {
+      for (_ <- 0 until Random.nextInt(numFields1)) {
+        row1.setNullAt(Random.nextInt(numFields1))
+      }
+    }
+    if (numFields2 > 0) {
+      for (_ <- 0 until Random.nextInt(numFields2)) {
+        row2.setNullAt(Random.nextInt(numFields2))
+      }
+    }
+
+    val concater = GenerateUnsafeRowJoiner.create(schema1, schema2)
+    val output = concater.join(row1, row2)
+
+    def dumpDebug(): String = {
+      val set1 = Seq.tabulate(numFields1) { i => if (row1.isNullAt(i)) "1" else "0" }
+      val set2 = Seq.tabulate(numFields2) { i => if (row2.isNullAt(i)) "1" else "0" }
+      val out = Seq.tabulate(numFields1 + numFields2) { i => if (output.isNullAt(i)) "1" else "0" }
+
+      s"""
+         |input1: ${set1.mkString}
+         |input2: ${set2.mkString}
+         |output: ${out.mkString}
+       """.stripMargin
+    }
+
+    // The output needs exactly enough bitset words for all fields, plus one slot per field.
+    val totalFields = numFields1 + numFields2
+    val expectedSize = ((totalFields + 63) / 64) * 8 + totalFields * 8
+    assert(output.getSizeInBytes === expectedSize)
+
+    for (i <- 0 until totalFields) {
+      if (i < numFields1) {
+        assert(output.isNullAt(i) === row1.isNullAt(i), dumpDebug())
+      } else {
+        assert(output.isNullAt(i) === row2.isNullAt(i - numFields1), dumpDebug())
+      }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/03377d25/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala
new file mode 100644
index 0000000..59729e7
--- /dev/null
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import scala.util.Random
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.RandomDataGenerator
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.sql.types._
+
+/**
+ * Test suite for [[GenerateUnsafeRowJoiner]].
+ *
+ * There is also a separate [[GenerateUnsafeRowJoinerBitsetSuite]] that tests specifically
+ * concatenation for the bitset portion, since that is the hardest one to get right.
+ */
+class GenerateUnsafeRowJoinerSuite extends SparkFunSuite {
+
+  private val fixed = Seq(IntegerType)
+  private val variable = Seq(IntegerType, StringType)
+
+  test("simple fixed width types") {
+    testConcat(0, 0, fixed)
+    testConcat(0, 1, fixed)
+    testConcat(1, 0, fixed)
+    testConcat(64, 0, fixed)
+    testConcat(0, 64, fixed)
+    testConcat(64, 64, fixed)
+  }
+
+  test("randomized fix width types") {
+    for (_ <- 0 until 20) {
+      testConcatOnce(Random.nextInt(100), Random.nextInt(100), fixed)
+    }
+  }
+
+  test("simple variable width types") {
+    testConcat(0, 0, variable)
+    testConcat(0, 1, variable)
+    testConcat(1, 0, variable)
+    testConcat(64, 0, variable)
+    testConcat(0, 64, variable)
+    testConcat(64, 64, variable)
+  }
+
+  test("randomized variable width types") {
+    for (_ <- 0 until 10) {
+      testConcatOnce(Random.nextInt(100), Random.nextInt(100), variable)
+    }
+  }
+
+  /** Runs [[testConcatOnce]] ten times, each with freshly drawn random schemas and data. */
+  private def testConcat(
+      numFields1: Int,
+      numFields2: Int,
+      candidateTypes: Seq[DataType]): Unit = {
+    for (_ <- 0 until 10) {
+      testConcatOnce(numFields1, numFields2, candidateTypes)
+    }
+  }
+
+  /**
+   * Builds two random rows whose column types are drawn from `candidateTypes`, joins them, and
+   * checks that every output field and the output's total size match the inputs.
+   */
+  private def testConcatOnce(
+      numFields1: Int,
+      numFields2: Int,
+      candidateTypes: Seq[DataType]): Unit = {
+    info(s"schema size $numFields1, $numFields2")
+    val schema1 = RandomDataGenerator.randomSchema(numFields1, candidateTypes)
+    val schema2 = RandomDataGenerator.randomSchema(numFields2, candidateTypes)
+
+    // Create the converters needed to convert from external row to internal row and to UnsafeRows.
+    val internalConverter1 = CatalystTypeConverters.createToCatalystConverter(schema1)
+    val internalConverter2 = CatalystTypeConverters.createToCatalystConverter(schema2)
+    val converter1 = UnsafeProjection.create(schema1)
+    val converter2 = UnsafeProjection.create(schema2)
+
+    // Create the input rows, convert them into UnsafeRows.
+    val extRow1 = RandomDataGenerator.forType(schema1, nullable = false).get.apply()
+    val extRow2 = RandomDataGenerator.forType(schema2, nullable = false).get.apply()
+    val row1 = converter1.apply(internalConverter1.apply(extRow1).asInstanceOf[InternalRow])
+    val row2 = converter2.apply(internalConverter2.apply(extRow2).asInstanceOf[InternalRow])
+
+    // Run the joiner.
+    val mergedSchema = StructType(schema1 ++ schema2)
+    val concater = GenerateUnsafeRowJoiner.create(schema1, schema2)
+    val output = concater.join(row1, row2)
+
+    // All fixed-length and variable-length bytes are kept; the only saving comes from packing
+    // both null bitsets into as few 64-bit words as possible.
+    val bitsetWords1 = (schema1.size + 63) / 64
+    val bitsetWords2 = (schema2.size + 63) / 64
+    val outputBitsetWords = (mergedSchema.size + 63) / 64
+    val expectedSize = row1.getSizeInBytes + row2.getSizeInBytes -
+      (bitsetWords1 + bitsetWords2 - outputBitsetWords) * 8
+    assert(output.getSizeInBytes === expectedSize)
+
+    // Test everything equals ...
+    for (i <- mergedSchema.indices) {
+      if (i < schema1.size) {
+        assert(output.isNullAt(i) === row1.isNullAt(i))
+        if (!output.isNullAt(i)) {
+          assert(output.get(i, mergedSchema(i).dataType) === row1.get(i, mergedSchema(i).dataType))
+        }
+      } else {
+        assert(output.isNullAt(i) === row2.isNullAt(i - schema1.size))
+        if (!output.isNullAt(i)) {
+          assert(output.get(i, mergedSchema(i).dataType) ===
+            row2.get(i - schema1.size, mergedSchema(i).dataType))
+        }
+      }
+    }
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/03377d25/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
 
b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index 66012e3..08a98cd 100644
--- 
a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ 
b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -73,12 +73,7 @@ public final class UnsafeFixedWidthAggregationMap {
    */
   public static boolean supportsAggregationBufferSchema(StructType schema) {
     for (StructField field: schema.fields()) {
-      if (field.dataType() instanceof DecimalType) {
-        DecimalType dt = (DecimalType) field.dataType();
-        if (dt.precision() > Decimal.MAX_LONG_DIGITS()) {
-          return false;
-        }
-      } else if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) {
+      if (!UnsafeRow.isFixedLength(field.dataType())) {
         return false;
       }
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/03377d25/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
index 4509635..b3f821e 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
@@ -26,6 +26,9 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.test.TestSQLContext
 import org.apache.spark.sql.types._
 
+/**
+ * A test suite that generates randomized data to test the [[TungstenSort]] operator.
+ */
 class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll {
 
   override def beforeAll(): Unit = {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to