Repository: spark
Updated Branches:
  refs/heads/master 137f47865 -> 191bf2689


[SPARK-9518] [SQL] cleanup generated UnsafeRowJoiner and fix bug

Currently, when copying the bitsets, we do not consider that row1 may not sit 
at the beginning of the byte array.

cc rxin

Author: Davies Liu <dav...@databricks.com>

Closes #7892 from davies/clean_join and squashes the following commits:

14cce9e [Davies Liu] cleanup generated UnsafeRowJoiner and fix bug


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/191bf268
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/191bf268
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/191bf268

Branch: refs/heads/master
Commit: 191bf2689d127a9dd328b9cc517362fd51eaed3d
Parents: 137f478
Author: Davies Liu <dav...@databricks.com>
Authored: Mon Aug 3 04:23:26 2015 -0700
Committer: Reynold Xin <r...@databricks.com>
Committed: Mon Aug 3 04:23:26 2015 -0700

----------------------------------------------------------------------
 .../codegen/GenerateUnsafeRowJoiner.scala       | 102 ++++++-------------
 .../GenerateUnsafeRowJoinerBitsetSuite.scala    |   7 +-
 2 files changed, 37 insertions(+), 72 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/191bf268/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
index 645eb48..5f8a6f8 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
@@ -40,10 +40,6 @@ abstract class UnsafeRowJoiner {
  */
 object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), 
UnsafeRowJoiner] {
 
-  def dump(word: Long): String = {
-    Seq.tabulate(64) { i => if ((word >> i) % 2 == 0) "0" else "1" 
}.reverse.mkString
-  }
-
   override protected def create(in: (StructType, StructType)): UnsafeRowJoiner 
= {
     create(in._1, in._2)
   }
@@ -56,76 +52,45 @@ object GenerateUnsafeRowJoiner extends 
CodeGenerator[(StructType, StructType), U
   }
 
   def create(schema1: StructType, schema2: StructType): UnsafeRowJoiner = {
-    val ctx = newCodeGenContext()
     val offset = PlatformDependent.BYTE_ARRAY_OFFSET
+    val getLong = "PlatformDependent.UNSAFE.getLong"
+    val putLong = "PlatformDependent.UNSAFE.putLong"
 
     val bitset1Words = (schema1.size + 63) / 64
     val bitset2Words = (schema2.size + 63) / 64
     val outputBitsetWords = (schema1.size + schema2.size + 63) / 64
     val bitset1Remainder = schema1.size % 64
-    val bitset2Remainder = schema2.size % 64
 
     // The number of words we can reduce when we concat two rows together.
     // The only reduction comes from merging the bitset portion of the two 
rows, saving 1 word.
     val sizeReduction = bitset1Words + bitset2Words - outputBitsetWords
 
-    // --------------------- copy bitset from row 1 ----------------------- //
-    val copyBitset1 = Seq.tabulate(bitset1Words) { i =>
-      s"""
-         |PlatformDependent.UNSAFE.putLong(buf, ${offset + i * 8},
-         |  PlatformDependent.UNSAFE.getLong(obj1, ${offset + i * 8}));
-       """.stripMargin
-    }.mkString
-
-
-    // --------------------- copy bitset from row 2 ----------------------- //
-    var copyBitset2 = ""
-    if (bitset1Remainder == 0) {
-      copyBitset2 += Seq.tabulate(bitset2Words) { i =>
-        s"""
-           |PlatformDependent.UNSAFE.putLong(buf, ${offset + (bitset1Words + 
i) * 8},
-           |  PlatformDependent.UNSAFE.getLong(obj2, ${offset + i * 8}));
-         """.stripMargin
-      }.mkString
-    } else {
-      copyBitset2 = Seq.tabulate(bitset2Words) { i =>
-        s"""
-           |long bs2w$i = PlatformDependent.UNSAFE.getLong(obj2, ${offset + i 
* 8});
-           |long bs2w${i}p1 = (bs2w$i << $bitset1Remainder) & ~((1L << 
$bitset1Remainder) - 1);
-           |long bs2w${i}p2 = (bs2w$i >>> ${64 - bitset1Remainder});
-         """.stripMargin
-      }.mkString
-
-      copyBitset2 += Seq.tabulate(bitset2Words) { i =>
-        val currentOffset = offset + (bitset1Words + i - 1) * 8
-        if (i == 0) {
-          if (bitset1Words > 0) {
-            s"""
-               |PlatformDependent.UNSAFE.putLong(buf, $currentOffset,
-               |  bs2w${i}p1 | PlatformDependent.UNSAFE.getLong(obj1, 
$currentOffset));
-            """.stripMargin
-          } else {
-            s"""
-               |PlatformDependent.UNSAFE.putLong(buf, $currentOffset + 8, 
bs2w${i}p1);
-            """.stripMargin
-          }
+    // --------------------- copy bitset from row 1 and row 2 
--------------------------- //
+    val copyBitset = Seq.tabulate(outputBitsetWords) { i =>
+      val bits = if (bitset1Remainder > 0) {
+        if (i < bitset1Words - 1) {
+          s"$getLong(obj1, offset1 + ${i * 8})"
+        } else if (i == bitset1Words - 1) {
+          // combine last word of bitset1 and first word of bitset2
+          s"$getLong(obj1, offset1 + ${i * 8}) | ($getLong(obj2, offset2) << 
$bitset1Remainder)"
+        } else if (i - bitset1Words < bitset2Words - 1) {
+          // combine next two words of bitset2
+          s"($getLong(obj2, offset2 + ${(i - bitset1Words) * 8}) >>> (64 - 
$bitset1Remainder))" +
+            s"| ($getLong(obj2, offset2 + ${(i - bitset1Words + 1) * 8}) << 
$bitset1Remainder)"
+        } else {
+          // last word of bitset2
+          s"$getLong(obj2, offset2 + ${(i - bitset1Words) * 8}) >>> (64 - 
$bitset1Remainder)"
+        }
+      } else {
+        // they are aligned by word
+        if (i < bitset1Words) {
+          s"$getLong(obj1, offset1 + ${i * 8})"
         } else {
-          s"""
-             |PlatformDependent.UNSAFE.putLong(buf, $currentOffset, bs2w${i}p1 
| bs2w${i - 1}p2);
-          """.stripMargin
+          s"$getLong(obj2, offset2 + ${(i - bitset1Words) * 8})"
         }
-      }.mkString("\n")
-
-      if (bitset2Words > 0 &&
-        (bitset2Remainder == 0 || bitset2Remainder > (64 - bitset1Remainder))) 
{
-        val lastWord = bitset2Words - 1
-        copyBitset2 +=
-          s"""
-             |PlatformDependent.UNSAFE.putLong(buf, ${offset + 
(outputBitsetWords - 1) * 8},
-             |  bs2w${lastWord}p2);
-          """.stripMargin
       }
-    }
+      s"$putLong(buf, ${offset + i * 8}, $bits);"
+    }.mkString("\n")
 
     // --------------------- copy fixed length portion from row 1 
----------------------- //
     var cursor = offset + outputBitsetWords * 8
@@ -149,10 +114,10 @@ object GenerateUnsafeRowJoiner extends 
CodeGenerator[(StructType, StructType), U
     cursor += schema2.size * 8
 
     // --------------------- copy variable length portion from row 1 
----------------------- //
+    val numBytesBitsetAndFixedRow1 = (bitset1Words + schema1.size) * 8
     val copyVariableLengthRow1 = s"""
        |// Copy variable length data for row1
-       |long numBytesBitsetAndFixedRow1 = ${(bitset1Words + schema1.size) * 8};
-       |long numBytesVariableRow1 = row1.getSizeInBytes() - 
numBytesBitsetAndFixedRow1;
+       |long numBytesVariableRow1 = row1.getSizeInBytes() - 
$numBytesBitsetAndFixedRow1;
        |PlatformDependent.copyMemory(
        |  obj1, offset1 + ${(bitset1Words + schema1.size) * 8},
        |  buf, $cursor,
@@ -160,10 +125,10 @@ object GenerateUnsafeRowJoiner extends 
CodeGenerator[(StructType, StructType), U
      """.stripMargin
 
     // --------------------- copy variable length portion from row 2 
----------------------- //
+    val numBytesBitsetAndFixedRow2 = (bitset2Words + schema2.size) * 8
     val copyVariableLengthRow2 = s"""
        |// Copy variable length data for row2
-       |long numBytesBitsetAndFixedRow2 = ${(bitset2Words + schema2.size) * 8};
-       |long numBytesVariableRow2 = row2.getSizeInBytes() - 
numBytesBitsetAndFixedRow2;
+       |long numBytesVariableRow2 = row2.getSizeInBytes() - 
$numBytesBitsetAndFixedRow2;
        |PlatformDependent.copyMemory(
        |  obj2, offset2 + ${(bitset2Words + schema2.size) * 8},
        |  buf, $cursor + numBytesVariableRow1,
@@ -183,12 +148,11 @@ object GenerateUnsafeRowJoiner extends 
CodeGenerator[(StructType, StructType), U
           if (i < schema1.size) {
             s"${(outputBitsetWords - bitset1Words + schema2.size) * 8}L"
           } else {
-            s"${(outputBitsetWords - bitset2Words + schema1.size) * 8}L + 
numBytesVariableRow1"
+            s"(${(outputBitsetWords - bitset2Words + schema1.size) * 8}L + 
numBytesVariableRow1)"
           }
         val cursor = offset + outputBitsetWords * 8 + i * 8
         s"""
-           |PlatformDependent.UNSAFE.putLong(buf, $cursor,
-           |  PlatformDependent.UNSAFE.getLong(buf, $cursor) + ($shift << 32));
+           |$putLong(buf, $cursor, $getLong(buf, $cursor) + ($shift << 32));
          """.stripMargin
       }
     }.mkString
@@ -217,8 +181,7 @@ object GenerateUnsafeRowJoiner extends 
CodeGenerator[(StructType, StructType), U
        |    final Object obj2 = row2.getBaseObject();
        |    final long offset2 = row2.getBaseOffset();
        |
-       |    $copyBitset1
-       |    $copyBitset2
+       |    $copyBitset
        |    $copyFixedLengthRow1
        |    $copyFixedLengthRow2
        |    $copyVariableLengthRow1
@@ -233,7 +196,6 @@ object GenerateUnsafeRowJoiner extends 
CodeGenerator[(StructType, StructType), U
      """.stripMargin
 
     logDebug(s"SpecificUnsafeRowJoiner($schema1, 
$schema2):\n${CodeFormatter.format(code)}")
-    // println(CodeFormatter.format(code))
 
     val c = compile(code)
     c.generate(Array.empty).asInstanceOf[UnsafeRowJoiner]

http://git-wip-us.apache.org/repos/asf/spark/blob/191bf268/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala
index 76d9d99..718a2ac 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala
@@ -22,6 +22,7 @@ import scala.util.Random
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.PlatformDependent
 
 /**
  * A test suite for the bitset portion of the row concatenation.
@@ -91,8 +92,9 @@ class GenerateUnsafeRowJoinerBitsetSuite extends 
SparkFunSuite {
   private def createUnsafeRow(numFields: Int): UnsafeRow = {
     val row = new UnsafeRow
     val sizeInBytes = numFields * 8 + ((numFields + 63) / 64) * 8
-    val buf = new Array[Byte](sizeInBytes)
-    row.pointTo(buf, numFields, sizeInBytes)
+    val offset = numFields * 8
+    val buf = new Array[Byte](sizeInBytes + offset)
+    row.pointTo(buf, PlatformDependent.BYTE_ARRAY_OFFSET + offset, numFields, 
sizeInBytes)
     row
   }
 
@@ -133,6 +135,7 @@ class GenerateUnsafeRowJoinerBitsetSuite extends 
SparkFunSuite {
          |input1: ${set1.mkString}
          |input2: ${set2.mkString}
          |output: ${out.mkString}
+         |expect: ${set1.mkString}${set2.mkString}
        """.stripMargin
     }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to