Repository: spark
Updated Branches:
  refs/heads/branch-1.5 c1838e430 -> 23842482f


[SPARK-9620] [SQL] generated UnsafeProjection should support many columns or 
large expressions

Currently, the generated UnsafeProjection can hit the JVM's 64KB bytecode 
limit for a single method. This patch splits the generated expressions into 
multiple functions to avoid that limitation.

After this patch, we can work well with tables that have up to 64k columns (at 
which point we hit Java's limit on the number of constants), which should be 
enough in practice.

cc rxin

Author: Davies Liu <dav...@databricks.com>

Closes #8044 from davies/wider_table and squashes the following commits:

9192e6c [Davies Liu] fix generated safe projection
d1ef81a [Davies Liu] fix failed tests
737b3d3 [Davies Liu] Merge branch 'master' of github.com:apache/spark into 
wider_table
ffcd132 [Davies Liu] address comments
1b95be4 [Davies Liu] put the generated class into sql package
77ed72d [Davies Liu] address comments
4518e17 [Davies Liu] Merge branch 'master' of github.com:apache/spark into 
wider_table
75ccd01 [Davies Liu] Merge branch 'master' of github.com:apache/spark into 
wider_table
495e932 [Davies Liu] support wider table with more than 1k columns for 
generated projections

(cherry picked from commit fe2fb7fb7189d183a4273ad27514af4b6b461f26)
Signed-off-by: Reynold Xin <r...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/23842482
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/23842482
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/23842482

Branch: refs/heads/branch-1.5
Commit: 23842482f46b45d90ba32ce406675cdb5f88c537
Parents: c1838e4
Author: Davies Liu <dav...@databricks.com>
Authored: Mon Aug 10 13:52:18 2015 -0700
Committer: Reynold Xin <r...@databricks.com>
Committed: Mon Aug 10 13:52:27 2015 -0700

----------------------------------------------------------------------
 .../expressions/codegen/CodeGenerator.scala     |  48 +++++++-
 .../codegen/GenerateMutableProjection.scala     |  43 +------
 .../codegen/GenerateSafeProjection.scala        |  52 ++------
 .../codegen/GenerateUnsafeProjection.scala      | 122 ++++++++++---------
 .../codegen/GenerateUnsafeRowJoiner.scala       |   2 +-
 .../codegen/GeneratedProjectionSuite.scala      |  82 +++++++++++++
 6 files changed, 207 insertions(+), 142 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/23842482/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 7b41c9a..c21f4d6 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions.codegen
 
 import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
 import scala.language.existentials
 
 import com.google.common.cache.{CacheBuilder, CacheLoader}
@@ -265,6 +266,45 @@ class CodeGenContext {
   def isPrimitiveType(jt: String): Boolean = primitiveTypes.contains(jt)
 
   def isPrimitiveType(dt: DataType): Boolean = isPrimitiveType(javaType(dt))
+
+  /**
+   * Splits the generated code of expressions into multiple functions, because 
function has
+   * 64kb code size limit in JVM
+   *
+   * @param row the variable name of row that is used by expressions
+   */
+  def splitExpressions(row: String, expressions: Seq[String]): String = {
+    val blocks = new ArrayBuffer[String]()
+    val blockBuilder = new StringBuilder()
+    for (code <- expressions) {
+      // We can't know how many byte code will be generated, so use the number 
of bytes as limit
+      if (blockBuilder.length > 64 * 1000) {
+        blocks.append(blockBuilder.toString())
+        blockBuilder.clear()
+      }
+      blockBuilder.append(code)
+    }
+    blocks.append(blockBuilder.toString())
+
+    if (blocks.length == 1) {
+      // inline execution if only one block
+      blocks.head
+    } else {
+      val apply = freshName("apply")
+      val functions = blocks.zipWithIndex.map { case (body, i) =>
+        val name = s"${apply}_$i"
+        val code = s"""
+           |private void $name(InternalRow $row) {
+           |  $body
+           |}
+         """.stripMargin
+         addNewFunction(name, code)
+         name
+      }
+
+      functions.map(name => s"$name($row);").mkString("\n")
+    }
+  }
 }
 
 /**
@@ -289,15 +329,15 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: 
AnyRef] extends Loggin
   protected def declareMutableStates(ctx: CodeGenContext): String = {
     ctx.mutableStates.map { case (javaType, variableName, _) =>
       s"private $javaType $variableName;"
-    }.mkString
+    }.mkString("\n")
   }
 
   protected def initMutableStates(ctx: CodeGenContext): String = {
-    ctx.mutableStates.map(_._3).mkString
+    ctx.mutableStates.map(_._3).mkString("\n")
   }
 
   protected def declareAddedFunctions(ctx: CodeGenContext): String = {
-    ctx.addedFuntions.map { case (funcName, funcCode) => funcCode }.mkString
+    ctx.addedFuntions.map { case (funcName, funcCode) => funcCode 
}.mkString("\n")
   }
 
   /**
@@ -328,6 +368,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: 
AnyRef] extends Loggin
   private[this] def doCompile(code: String): GeneratedClass = {
     val evaluator = new ClassBodyEvaluator()
     evaluator.setParentClassLoader(getClass.getClassLoader)
+    // Cannot be under package codegen, or fail with 
java.lang.InstantiationException
+    
evaluator.setClassName("org.apache.spark.sql.catalyst.expressions.GeneratedClass")
     evaluator.setDefaultImports(Array(
       classOf[PlatformDependent].getName,
       classOf[InternalRow].getName,

http://git-wip-us.apache.org/repos/asf/spark/blob/23842482/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index ac58423..b4d4df8 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -40,7 +40,7 @@ object GenerateMutableProjection extends 
CodeGenerator[Seq[Expression], () => Mu
 
   protected def create(expressions: Seq[Expression]): (() => 
MutableProjection) = {
     val ctx = newCodeGenContext()
-    val projectionCode = expressions.zipWithIndex.map {
+    val projectionCodes = expressions.zipWithIndex.map {
       case (NoOp, _) => ""
       case (e, i) =>
         val evaluationCode = e.gen(ctx)
@@ -65,49 +65,21 @@ object GenerateMutableProjection extends 
CodeGenerator[Seq[Expression], () => Mu
           """
         }
     }
-    // collect projections into blocks as function has 64kb codesize limit in 
JVM
-    val projectionBlocks = new ArrayBuffer[String]()
-    val blockBuilder = new StringBuilder()
-    for (projection <- projectionCode) {
-      if (blockBuilder.length > 16 * 1000) {
-        projectionBlocks.append(blockBuilder.toString())
-        blockBuilder.clear()
-      }
-      blockBuilder.append(projection)
-    }
-    projectionBlocks.append(blockBuilder.toString())
-
-    val (projectionFuns, projectionCalls) = {
-      // inline execution if codesize limit was not broken
-      if (projectionBlocks.length == 1) {
-        ("", projectionBlocks.head)
-      } else {
-        (
-          projectionBlocks.zipWithIndex.map { case (body, i) =>
-            s"""
-               |private void apply$i(InternalRow i) {
-               |  $body
-               |}
-             """.stripMargin
-          }.mkString,
-          projectionBlocks.indices.map(i => s"apply$i(i);").mkString("\n")
-        )
-      }
-    }
+    val allProjections = ctx.splitExpressions("i", projectionCodes)
 
     val code = s"""
       public Object generate($exprType[] expr) {
-        return new SpecificProjection(expr);
+        return new SpecificMutableProjection(expr);
       }
 
-      class SpecificProjection extends 
${classOf[BaseMutableProjection].getName} {
+      class SpecificMutableProjection extends 
${classOf[BaseMutableProjection].getName} {
 
         private $exprType[] expressions;
         private $mutableRowType mutableRow;
         ${declareMutableStates(ctx)}
         ${declareAddedFunctions(ctx)}
 
-        public SpecificProjection($exprType[] expr) {
+        public SpecificMutableProjection($exprType[] expr) {
           expressions = expr;
           mutableRow = new $genericMutableRowType(${expressions.size});
           ${initMutableStates(ctx)}
@@ -123,12 +95,9 @@ object GenerateMutableProjection extends 
CodeGenerator[Seq[Expression], () => Mu
           return (InternalRow) mutableRow;
         }
 
-        $projectionFuns
-
         public Object apply(Object _i) {
           InternalRow i = (InternalRow) _i;
-          $projectionCalls
-
+          $allProjections
           return mutableRow;
         }
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/23842482/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
index ef08ddf..7ad352d 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.catalyst.expressions.codegen
 
-import scala.collection.mutable.ArrayBuffer
-
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
 import org.apache.spark.sql.types._
@@ -43,6 +41,9 @@ object GenerateSafeProjection extends 
CodeGenerator[Seq[Expression], Projection]
     val tmp = ctx.freshName("tmp")
     val output = ctx.freshName("safeRow")
     val values = ctx.freshName("values")
+    // These expressions could be splitted into multiple functions
+    ctx.addMutableState("Object[]", values, s"this.$values = null;")
+
     val rowClass = classOf[GenericInternalRow].getName
 
     val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) 
=>
@@ -53,12 +54,12 @@ object GenerateSafeProjection extends 
CodeGenerator[Seq[Expression], Projection]
           $values[$i] = ${converter.primitive};
         }
       """
-    }.mkString("\n")
-
+    }
+    val allFields = ctx.splitExpressions(tmp, fieldWriters)
     val code = s"""
       final InternalRow $tmp = $input;
-      final Object[] $values = new Object[${schema.length}];
-      $fieldWriters
+      this.$values = new Object[${schema.length}];
+      $allFields
       final InternalRow $output = new $rowClass($values);
     """
 
@@ -128,7 +129,7 @@ object GenerateSafeProjection extends 
CodeGenerator[Seq[Expression], Projection]
 
   protected def create(expressions: Seq[Expression]): Projection = {
     val ctx = newCodeGenContext()
-    val projectionCode = expressions.zipWithIndex.map {
+    val expressionCodes = expressions.zipWithIndex.map {
       case (NoOp, _) => ""
       case (e, i) =>
         val evaluationCode = e.gen(ctx)
@@ -143,36 +144,7 @@ object GenerateSafeProjection extends 
CodeGenerator[Seq[Expression], Projection]
             }
           """
     }
-    // collect projections into blocks as function has 64kb codesize limit in 
JVM
-    val projectionBlocks = new ArrayBuffer[String]()
-    val blockBuilder = new StringBuilder()
-    for (projection <- projectionCode) {
-      if (blockBuilder.length > 16 * 1000) {
-        projectionBlocks.append(blockBuilder.toString())
-        blockBuilder.clear()
-      }
-      blockBuilder.append(projection)
-    }
-    projectionBlocks.append(blockBuilder.toString())
-
-    val (projectionFuns, projectionCalls) = {
-      // inline it if we have only one block
-      if (projectionBlocks.length == 1) {
-        ("", projectionBlocks.head)
-      } else {
-        (
-          projectionBlocks.zipWithIndex.map { case (body, i) =>
-            s"""
-               |private void apply$i(InternalRow i) {
-               |  $body
-               |}
-             """.stripMargin
-          }.mkString,
-          projectionBlocks.indices.map(i => s"apply$i(i);").mkString("\n")
-          )
-      }
-    }
-
+    val allExpressions = ctx.splitExpressions("i", expressionCodes)
     val code = s"""
       public Object generate($exprType[] expr) {
         return new SpecificSafeProjection(expr);
@@ -183,6 +155,7 @@ object GenerateSafeProjection extends 
CodeGenerator[Seq[Expression], Projection]
         private $exprType[] expressions;
         private $mutableRowType mutableRow;
         ${declareMutableStates(ctx)}
+        ${declareAddedFunctions(ctx)}
 
         public SpecificSafeProjection($exprType[] expr) {
           expressions = expr;
@@ -190,12 +163,9 @@ object GenerateSafeProjection extends 
CodeGenerator[Seq[Expression], Projection]
           ${initMutableStates(ctx)}
         }
 
-        $projectionFuns
-
         public Object apply(Object _i) {
           InternalRow i = (InternalRow) _i;
-          $projectionCalls
-
+          $allExpressions
           return mutableRow;
         }
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/23842482/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index d8912df..29f6a7b 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions.codegen
 
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.PlatformDependent
 
 /**
  * Generates a [[Projection]] that returns an [[UnsafeRow]].
@@ -41,8 +40,6 @@ object GenerateUnsafeProjection extends 
CodeGenerator[Seq[Expression], UnsafePro
   private val ArrayWriter = classOf[UnsafeRowWriters.ArrayWriter].getName
   private val MapWriter = classOf[UnsafeRowWriters.MapWriter].getName
 
-  private val PlatformDependent = classOf[PlatformDependent].getName
-
   /** Returns true iff we support this data type. */
   def canSupport(dataType: DataType): Boolean = dataType match {
     case NullType => true
@@ -56,19 +53,19 @@ object GenerateUnsafeProjection extends 
CodeGenerator[Seq[Expression], UnsafePro
 
   def genAdditionalSize(dt: DataType, ev: GeneratedExpressionCode): String = 
dt match {
     case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS =>
-      s" + $DecimalWriter.getSize(${ev.primitive})"
+      s"$DecimalWriter.getSize(${ev.primitive})"
     case StringType =>
-      s" + (${ev.isNull} ? 0 : $StringWriter.getSize(${ev.primitive}))"
+      s"${ev.isNull} ? 0 : $StringWriter.getSize(${ev.primitive})"
     case BinaryType =>
-      s" + (${ev.isNull} ? 0 : $BinaryWriter.getSize(${ev.primitive}))"
+      s"${ev.isNull} ? 0 : $BinaryWriter.getSize(${ev.primitive})"
     case CalendarIntervalType =>
-      s" + (${ev.isNull} ? 0 : 16)"
+      s"${ev.isNull} ? 0 : 16"
     case _: StructType =>
-      s" + (${ev.isNull} ? 0 : $StructWriter.getSize(${ev.primitive}))"
+      s"${ev.isNull} ? 0 : $StructWriter.getSize(${ev.primitive})"
     case _: ArrayType =>
-      s" + (${ev.isNull} ? 0 : $ArrayWriter.getSize(${ev.primitive}))"
+      s"${ev.isNull} ? 0 : $ArrayWriter.getSize(${ev.primitive})"
     case _: MapType =>
-      s" + (${ev.isNull} ? 0 : $MapWriter.getSize(${ev.primitive}))"
+      s"${ev.isNull} ? 0 : $MapWriter.getSize(${ev.primitive})"
     case _ => ""
   }
 
@@ -125,64 +122,69 @@ object GenerateUnsafeProjection extends 
CodeGenerator[Seq[Expression], UnsafePro
    */
   private def createCodeForStruct(
       ctx: CodeGenContext,
+      row: String,
       inputs: Seq[GeneratedExpressionCode],
       inputTypes: Seq[DataType]): GeneratedExpressionCode = {
 
+    val fixedSize = 8 * inputTypes.length + 
UnsafeRow.calculateBitSetWidthInBytes(inputTypes.length)
+
     val output = ctx.freshName("convertedStruct")
-    ctx.addMutableState("UnsafeRow", output, s"$output = new UnsafeRow();")
+    ctx.addMutableState("UnsafeRow", output, s"this.$output = new 
UnsafeRow();")
     val buffer = ctx.freshName("buffer")
-    ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];")
-    val numBytes = ctx.freshName("numBytes")
+    ctx.addMutableState("byte[]", buffer, s"this.$buffer = new 
byte[$fixedSize];")
     val cursor = ctx.freshName("cursor")
+    ctx.addMutableState("int", cursor, s"this.$cursor = 0;")
+    val tmp = ctx.freshName("tmpBuffer")
 
-    val convertedFields = inputTypes.zip(inputs).map { case (dt, input) =>
-      createConvertCode(ctx, input, dt)
-    }
-
-    val fixedSize = 8 * inputTypes.length + 
UnsafeRow.calculateBitSetWidthInBytes(inputTypes.length)
-    val additionalSize = inputTypes.zip(convertedFields).map { case (dt, ev) =>
-      genAdditionalSize(dt, ev)
-    }.mkString("")
-
-    val fieldWriters = inputTypes.zip(convertedFields).zipWithIndex.map { case 
((dt, ev), i) =>
-      val update = genFieldWriter(ctx, dt, ev, output, i, cursor)
-      if (dt.isInstanceOf[DecimalType]) {
-        // Can't call setNullAt() for DecimalType
+    val convertedFields = inputTypes.zip(inputs).zipWithIndex.map { case ((dt, 
input), i) =>
+      val ev = createConvertCode(ctx, input, dt)
+      val growBuffer = if (!UnsafeRow.isFixedLength(dt)) {
+        val numBytes = ctx.freshName("numBytes")
         s"""
+          int $numBytes = $cursor + (${genAdditionalSize(dt, ev)});
+          if ($buffer.length < $numBytes) {
+            // This will not happen frequently, because the buffer is re-used.
+            byte[] $tmp = new byte[$numBytes * 2];
+            PlatformDependent.copyMemory($buffer, 
PlatformDependent.BYTE_ARRAY_OFFSET,
+              $tmp, PlatformDependent.BYTE_ARRAY_OFFSET, $buffer.length);
+            $buffer = $tmp;
+          }
+          $output.pointTo($buffer, PlatformDependent.BYTE_ARRAY_OFFSET,
+            ${inputTypes.length}, $numBytes);
+         """
+      } else {
+        ""
+      }
+      val update = dt match {
+        case dt: DecimalType if dt.precision > Decimal.MAX_LONG_DIGITS =>
+          // Can't call setNullAt() for DecimalType
+          s"""
           if (${ev.isNull}) {
-           $cursor += $DecimalWriter.write($output, $i, $cursor, null);
+            $cursor += $DecimalWriter.write($output, $i, $cursor, null);
           } else {
-           $update;
+            ${genFieldWriter(ctx, dt, ev, output, i, cursor)};
           }
         """
-      } else {
-        s"""
+        case _ =>
+          s"""
           if (${ev.isNull}) {
             $output.setNullAt($i);
           } else {
-            $update;
+            ${genFieldWriter(ctx, dt, ev, output, i, cursor)};
           }
         """
       }
-    }.mkString("\n")
+      s"""
+        ${ev.code}
+        $growBuffer
+        $update
+      """
+    }
 
     val code = s"""
-      ${convertedFields.map(_.code).mkString("\n")}
-
-      final int $numBytes = $fixedSize $additionalSize;
-      if ($numBytes > $buffer.length) {
-        $buffer = new byte[$numBytes];
-      }
-
-      $output.pointTo(
-        $buffer,
-        $PlatformDependent.BYTE_ARRAY_OFFSET,
-        ${inputTypes.length},
-        $numBytes);
-
-      int $cursor = $fixedSize;
-
-      $fieldWriters
+      $cursor = $fixedSize;
+      $output.pointTo($buffer, PlatformDependent.BYTE_ARRAY_OFFSET, 
${inputTypes.length}, $cursor);
+      ${ctx.splitExpressions(row, convertedFields)}
       """
     GeneratedExpressionCode(code, "false", output)
   }
@@ -265,17 +267,17 @@ object GenerateUnsafeProjection extends 
CodeGenerator[Seq[Expression], UnsafePro
         // Should we do word align?
         val elementSize = elementType.defaultSize
         s"""
-          $PlatformDependent.UNSAFE.put${ctx.primitiveTypeName(elementType)}(
+          PlatformDependent.UNSAFE.put${ctx.primitiveTypeName(elementType)}(
             $buffer,
-            $PlatformDependent.BYTE_ARRAY_OFFSET + $cursor,
+            PlatformDependent.BYTE_ARRAY_OFFSET + $cursor,
             ${convertedElement.primitive});
           $cursor += $elementSize;
         """
       case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS =>
         s"""
-          $PlatformDependent.UNSAFE.putLong(
+          PlatformDependent.UNSAFE.putLong(
             $buffer,
-            $PlatformDependent.BYTE_ARRAY_OFFSET + $cursor,
+            PlatformDependent.BYTE_ARRAY_OFFSET + $cursor,
             ${convertedElement.primitive}.toUnscaledLong());
           $cursor += 8;
         """
@@ -284,7 +286,7 @@ object GenerateUnsafeProjection extends 
CodeGenerator[Seq[Expression], UnsafePro
         s"""
           $cursor += $writer.write(
             $buffer,
-            $PlatformDependent.BYTE_ARRAY_OFFSET + $cursor,
+            PlatformDependent.BYTE_ARRAY_OFFSET + $cursor,
             $elements[$index]);
         """
     }
@@ -318,14 +320,14 @@ object GenerateUnsafeProjection extends 
CodeGenerator[Seq[Expression], UnsafePro
           for (int $index = 0; $index < $numElements; $index++) {
             if ($checkNull) {
               // If element is null, write the negative value address into 
offset region.
-              $PlatformDependent.UNSAFE.putInt(
+              PlatformDependent.UNSAFE.putInt(
                 $buffer,
-                $PlatformDependent.BYTE_ARRAY_OFFSET + 4 * $index,
+                PlatformDependent.BYTE_ARRAY_OFFSET + 4 * $index,
                 -$cursor);
             } else {
-              $PlatformDependent.UNSAFE.putInt(
+              PlatformDependent.UNSAFE.putInt(
                 $buffer,
-                $PlatformDependent.BYTE_ARRAY_OFFSET + 4 * $index,
+                PlatformDependent.BYTE_ARRAY_OFFSET + 4 * $index,
                 $cursor);
 
               $writeElement
@@ -334,7 +336,7 @@ object GenerateUnsafeProjection extends 
CodeGenerator[Seq[Expression], UnsafePro
 
           $output.pointTo(
             $buffer,
-            $PlatformDependent.BYTE_ARRAY_OFFSET,
+            PlatformDependent.BYTE_ARRAY_OFFSET,
             $numElements,
             $numBytes);
         }
@@ -400,7 +402,7 @@ object GenerateUnsafeProjection extends 
CodeGenerator[Seq[Expression], UnsafePro
         val fieldIsNull = s"$tmp.isNullAt($i)"
         GeneratedExpressionCode("", fieldIsNull, getFieldCode)
       }
-      val converter = createCodeForStruct(ctx, fieldEvals, fieldTypes)
+      val converter = createCodeForStruct(ctx, tmp, fieldEvals, fieldTypes)
       val code = s"""
         ${input.code}
          UnsafeRow $output = null;
@@ -427,7 +429,7 @@ object GenerateUnsafeProjection extends 
CodeGenerator[Seq[Expression], UnsafePro
   def createCode(ctx: CodeGenContext, expressions: Seq[Expression]): 
GeneratedExpressionCode = {
     val exprEvals = expressions.map(e => e.gen(ctx))
     val exprTypes = expressions.map(_.dataType)
-    createCodeForStruct(ctx, exprEvals, exprTypes)
+    createCodeForStruct(ctx, "i", exprEvals, exprTypes)
   }
 
   protected def canonicalize(in: Seq[Expression]): Seq[Expression] =

http://git-wip-us.apache.org/repos/asf/spark/blob/23842482/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
index 30b51dd..8aaa5b4 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
@@ -155,7 +155,7 @@ object GenerateUnsafeRowJoiner extends 
CodeGenerator[(StructType, StructType), U
            |$putLong(buf, $cursor, $getLong(buf, $cursor) + ($shift << 32));
          """.stripMargin
       }
-    }.mkString
+    }.mkString("\n")
 
     // ------------------------ Finally, put everything together  
--------------------------- //
     val code = s"""

http://git-wip-us.apache.org/repos/asf/spark/blob/23842482/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala
new file mode 100644
index 0000000..8c7ee87
--- /dev/null
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.{StringType, IntegerType, StructField, 
StructType}
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * A test suite for generated projections
+ */
+class GeneratedProjectionSuite extends SparkFunSuite {
+
+  test("generated projections on wider table") {
+    val N = 1000
+    val wideRow1 = new GenericInternalRow((1 to N).toArray[Any])
+    val schema1 = StructType((1 to N).map(i => StructField("", IntegerType)))
+    val wideRow2 = new GenericInternalRow(
+      (1 to N).map(i => UTF8String.fromString(i.toString)).toArray[Any])
+    val schema2 = StructType((1 to N).map(i => StructField("", StringType)))
+    val joined = new JoinedRow(wideRow1, wideRow2)
+    val joinedSchema = StructType(schema1 ++ schema2)
+    val nested = new JoinedRow(InternalRow(joined, joined), joined)
+    val nestedSchema = StructType(
+      Seq(StructField("", joinedSchema), StructField("", joinedSchema)) ++ 
joinedSchema)
+
+    // test generated UnsafeProjection
+    val unsafeProj = UnsafeProjection.create(nestedSchema)
+    val unsafe: UnsafeRow = unsafeProj(nested)
+    (0 until N).foreach { i =>
+      val s = UTF8String.fromString((i + 1).toString)
+      assert(i + 1 === unsafe.getInt(i + 2))
+      assert(s === unsafe.getUTF8String(i + 2 + N))
+      assert(i + 1 === unsafe.getStruct(0, N * 2).getInt(i))
+      assert(s === unsafe.getStruct(0, N * 2).getUTF8String(i + N))
+      assert(i + 1 === unsafe.getStruct(1, N * 2).getInt(i))
+      assert(s === unsafe.getStruct(1, N * 2).getUTF8String(i + N))
+    }
+
+    // test generated SafeProjection
+    val safeProj = FromUnsafeProjection(nestedSchema)
+    val result = safeProj(unsafe)
+    // Can't compare GenericInternalRow with JoinedRow directly
+    (0 until N).foreach { i =>
+      val r = i + 1
+      val s = UTF8String.fromString((i + 1).toString)
+      assert(r === result.getInt(i + 2))
+      assert(s === result.getUTF8String(i + 2 + N))
+      assert(r === result.getStruct(0, N * 2).getInt(i))
+      assert(s === result.getStruct(0, N * 2).getUTF8String(i + N))
+      assert(r === result.getStruct(1, N * 2).getInt(i))
+      assert(s === result.getStruct(1, N * 2).getUTF8String(i + N))
+    }
+
+    // test generated MutableProjection
+    val exprs = nestedSchema.fields.zipWithIndex.map { case (f, i) =>
+      BoundReference(i, f.dataType, true)
+    }
+    val mutableProj = GenerateMutableProjection.generate(exprs)()
+    val row1 = mutableProj(result)
+    assert(result === row1)
+    val row2 = mutableProj(result)
+    assert(result === row2)
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to