Repository: spark
Updated Branches:
  refs/heads/master f6680cdc5 -> 7e9a9e603


[SPARK-11269][SQL] Java API support & test cases for Dataset

This simply brings https://github.com/apache/spark/pull/9358 up to date.

Author: Wenchen Fan <wenc...@databricks.com>
Author: Reynold Xin <r...@databricks.com>

Closes #9528 from rxin/dataset-java.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7e9a9e60
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7e9a9e60
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7e9a9e60

Branch: refs/heads/master
Commit: 7e9a9e603abce8689938bdd62d04b29299644aa4
Parents: f6680cd
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Fri Nov 6 15:37:07 2015 -0800
Committer: Reynold Xin <r...@databricks.com>
Committed: Fri Nov 6 15:37:07 2015 -0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/encoders/Encoder.scala   | 123 ++++++-
 .../sql/catalyst/expressions/objects.scala      |  21 ++
 .../scala/org/apache/spark/sql/Dataset.scala    | 126 ++++++-
 .../org/apache/spark/sql/DatasetHolder.scala    |   6 +-
 .../org/apache/spark/sql/GroupedDataset.scala   |  17 +
 .../scala/org/apache/spark/sql/SQLContext.scala |   4 +
 .../org/apache/spark/sql/JavaDatasetSuite.java  | 357 +++++++++++++++++++
 .../spark/sql/DatasetPrimitiveSuite.scala       |   2 +-
 8 files changed, 644 insertions(+), 12 deletions(-)
----------------------------------------------------------------------
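
For readers skimming the diff, the JavaDatasetSuite added below exercises the new API end to end. As a quick orientation, here is a minimal, self-contained sketch of what the Java-facing Dataset API looks like after this patch. The class and method names (createDataset, map with an explicit encoder, collectAsList, Encoder$.MODULE$) are taken from the test suite in this commit; the local-mode setup is only illustrative and not part of the patch.

    import java.util.Arrays;
    import java.util.List;

    import org.apache.spark.SparkContext;
    import org.apache.spark.api.java.function.Function;
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.SQLContext;
    import org.apache.spark.sql.catalyst.encoders.Encoder$;

    public class JavaDatasetExample {
      public static void main(String[] args) {
        SparkContext sc = new SparkContext("local[*]", "dataset-java-example");
        SQLContext context = new SQLContext(sc);
        Encoder$ e = Encoder$.MODULE$;   // factory for the primitive/tuple encoders added below

        // createDataset(java.util.List, Encoder) is one of the new Java-friendly entry points.
        Dataset<String> ds = context.createDataset(Arrays.asList("hello", "world"), e.STRING());

        // Java-specific overloads take a function object plus an explicit result encoder.
        Dataset<Integer> lengths = ds.map(new Function<String, Integer>() {
          @Override
          public Integer call(String v) throws Exception {
            return v.length();
          }
        }, e.INT());

        // collectAsList() keeps the element type; from Java, collect() comes back as a plain Object.
        List<Integer> result = lengths.collectAsList();   // [5, 5]
        System.out.println(result);

        sc.stop();
      }
    }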


http://git-wip-us.apache.org/repos/asf/spark/blob/7e9a9e60/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
index 329a132..f05e1828 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
@@ -17,11 +17,11 @@
 
 package org.apache.spark.sql.catalyst.encoders
 
-
-
 import scala.reflect.ClassTag
 
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.Utils
+import org.apache.spark.sql.types.{DataType, ObjectType, StructField, StructType}
+import org.apache.spark.sql.catalyst.expressions._
 
 /**
 * Used to convert a JVM object of type `T` to and from the internal Spark SQL representation.
@@ -37,3 +37,120 @@ trait Encoder[T] extends Serializable {
   /** A ClassTag that can be used to construct and Array to contain a collection of `T`. */
   def clsTag: ClassTag[T]
 }
+
+object Encoder {
+  import scala.reflect.runtime.universe._
+
+  def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true)
+  def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true)
+  def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true)
+  def INT: Encoder[java.lang.Integer] = ExpressionEncoder(flat = true)
+  def LONG: Encoder[java.lang.Long] = ExpressionEncoder(flat = true)
+  def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder(flat = true)
+  def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder(flat = true)
+  def STRING: Encoder[java.lang.String] = ExpressionEncoder(flat = true)
+
+  def tuple[T1, T2](enc1: Encoder[T1], enc2: Encoder[T2]): Encoder[(T1, T2)] = {
+    tuple(Seq(enc1, enc2).map(_.asInstanceOf[ExpressionEncoder[_]]))
+      .asInstanceOf[ExpressionEncoder[(T1, T2)]]
+  }
+
+  def tuple[T1, T2, T3](
+      enc1: Encoder[T1],
+      enc2: Encoder[T2],
+      enc3: Encoder[T3]): Encoder[(T1, T2, T3)] = {
+    tuple(Seq(enc1, enc2, enc3).map(_.asInstanceOf[ExpressionEncoder[_]]))
+      .asInstanceOf[ExpressionEncoder[(T1, T2, T3)]]
+  }
+
+  def tuple[T1, T2, T3, T4](
+      enc1: Encoder[T1],
+      enc2: Encoder[T2],
+      enc3: Encoder[T3],
+      enc4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = {
+    tuple(Seq(enc1, enc2, enc3, enc4).map(_.asInstanceOf[ExpressionEncoder[_]]))
+      .asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4)]]
+  }
+
+  def tuple[T1, T2, T3, T4, T5](
+      enc1: Encoder[T1],
+      enc2: Encoder[T2],
+      enc3: Encoder[T3],
+      enc4: Encoder[T4],
+      enc5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = {
+    tuple(Seq(enc1, enc2, enc3, enc4, enc5).map(_.asInstanceOf[ExpressionEncoder[_]]))
+      .asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]]
+  }
+
+  private def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
+    assert(encoders.length > 1)
+    // make sure all encoders are resolved, i.e. `Attribute` has been resolved to `BoundReference`.
+    assert(encoders.forall(_.constructExpression.find(_.isInstanceOf[Attribute]).isEmpty))
+
+    val schema = StructType(encoders.zipWithIndex.map {
+      case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)
+    })
+
+    val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
+
+    val extractExpressions = encoders.map {
+      case e if e.flat => e.extractExpressions.head
+      case other => CreateStruct(other.extractExpressions)
+    }.zipWithIndex.map { case (expr, index) =>
+      expr.transformUp {
+        case BoundReference(0, t: ObjectType, _) =>
+          Invoke(
+            BoundReference(0, ObjectType(cls), true),
+            s"_${index + 1}",
+            t)
+      }
+    }
+
+    val constructExpressions = encoders.zipWithIndex.map { case (enc, index) =>
+      if (enc.flat) {
+        enc.constructExpression.transform {
+          case b: BoundReference => b.copy(ordinal = index)
+        }
+      } else {
+        enc.constructExpression.transformUp {
+          case BoundReference(ordinal, dt, _) =>
+            GetInternalRowField(BoundReference(index, enc.schema, true), ordinal, dt)
+        }
+      }
+    }
+
+    val constructExpression =
+      NewInstance(cls, constructExpressions, false, ObjectType(cls))
+
+    new ExpressionEncoder[Any](
+      schema,
+      false,
+      extractExpressions,
+      constructExpression,
+      ClassTag.apply(cls))
+  }
+
+
+  def typeTagOfTuple2[T1 : TypeTag, T2 : TypeTag]: TypeTag[(T1, T2)] = typeTag[(T1, T2)]
+
+  private def getTypeTag[T](c: Class[T]): TypeTag[T] = {
+    import scala.reflect.api
+
+    // val mirror = runtimeMirror(c.getClassLoader)
+    val mirror = rootMirror
+    val sym = mirror.staticClass(c.getName)
+    val tpe = sym.selfType
+    TypeTag(mirror, new api.TypeCreator {
+      def apply[U <: api.Universe with Singleton](m: api.Mirror[U]) =
+        if (m eq mirror) tpe.asInstanceOf[U # Type]
+        else throw new IllegalArgumentException(
+          s"Type tag defined in $mirror cannot be migrated to other mirrors.")
+    })
+  }
+
+  def forTuple2[T1, T2](c1: Class[T1], c2: Class[T2]): Encoder[(T1, T2)] = {
+    implicit val typeTag1 = getTypeTag(c1)
+    implicit val typeTag2 = getTypeTag(c2)
+    ExpressionEncoder[(T1, T2)]()
+  }
+}
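
The factory methods above (BOOLEAN ... STRING, tuple) are what Java callers use in place of Scala's implicit encoders. A short sketch of composing them, mirroring testTupleEncoder/testNestedTupleEncoder further down; it reuses the `context` and `e = Encoder$.MODULE$` setup from the earlier example and omits the imports (scala.Tuple2, java.util.Arrays/List, org.apache.spark.sql.Dataset, org.apache.spark.sql.catalyst.encoders.Encoder):

    // Primitive encoders come from the new Encoder object; tuple() composes them,
    // and composed encoders nest, so composite values round-trip unchanged.
    Encoder<Tuple2<Integer, String>> pair = e.tuple(e.INT(), e.STRING());
    Dataset<Tuple2<Integer, String>> ds = context.createDataset(
      Arrays.asList(new Tuple2<Integer, String>(1, "a"), new Tuple2<Integer, String>(2, "b")),
      pair);

    // Nested shape: ((int, string), string)
    Encoder<Tuple2<Tuple2<Integer, String>, String>> nested = e.tuple(pair, e.STRING());

    List<Tuple2<Integer, String>> roundTripped = ds.collectAsList();   // [(1,a), (2,b)]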

http://git-wip-us.apache.org/repos/asf/spark/blob/7e9a9e60/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index 8185528..4f58464 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -491,3 +491,24 @@ case class CreateExternalRow(children: Seq[Expression]) extends Expression {
       s"final ${classOf[Row].getName} ${ev.value} = new $rowClass($values);"
   }
 }
+
+case class GetInternalRowField(child: Expression, ordinal: Int, dataType: DataType)
+  extends UnaryExpression {
+
+  override def nullable: Boolean = true
+
+  override def eval(input: InternalRow): Any =
+    throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val row = child.gen(ctx)
+    s"""
+      ${row.code}
+      final boolean ${ev.isNull} = ${row.isNull};
+      ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+      if (!${ev.isNull}) {
+        ${ev.value} = ${ctx.getValue(row.value, dataType, ordinal.toString)};
+      }
+    """
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/7e9a9e60/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 4bca9c3..fecbdac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -17,9 +17,13 @@
 
 package org.apache.spark.sql
 
+import scala.collection.JavaConverters._
+
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
+import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, _}
+
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.Inner
@@ -151,18 +155,37 @@ class Dataset[T] private[sql](
   def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] = t(this)
 
   /**
+   * (Scala-specific)
    * Returns a new [[Dataset]] that only contains elements where `func` returns `true`.
    * @since 1.6.0
    */
   def filter(func: T => Boolean): Dataset[T] = mapPartitions(_.filter(func))
 
   /**
+   * (Java-specific)
+   * Returns a new [[Dataset]] that only contains elements where `func` returns `true`.
+   * @since 1.6.0
+   */
+  def filter(func: JFunction[T, java.lang.Boolean]): Dataset[T] =
+    filter(t => func.call(t).booleanValue())
+
+  /**
+   * (Scala-specific)
    * Returns a new [[Dataset]] that contains the result of applying `func` to each element.
    * @since 1.6.0
    */
   def map[U : Encoder](func: T => U): Dataset[U] = mapPartitions(_.map(func))
 
   /**
+   * (Java-specific)
+   * Returns a new [[Dataset]] that contains the result of applying `func` to each element.
+   * @since 1.6.0
+   */
+  def map[U](func: JFunction[T, U], encoder: Encoder[U]): Dataset[U] =
+    map(t => func.call(t))(encoder)
+
+  /**
+   * (Scala-specific)
    * Returns a new [[Dataset]] that contains the result of applying `func` to each element.
    * @since 1.6.0
    */
@@ -177,30 +200,77 @@ class Dataset[T] private[sql](
         logicalPlan))
   }
 
+  /**
+   * (Java-specific)
+   * Returns a new [[Dataset]] that contains the result of applying `func` to each element.
+   * @since 1.6.0
+   */
+  def mapPartitions[U](
+      f: FlatMapFunction[java.util.Iterator[T], U],
+      encoder: Encoder[U]): Dataset[U] = {
+    val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).iterator().asScala
+    mapPartitions(func)(encoder)
+  }
+
+  /**
+   * (Scala-specific)
+   * Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]],
+   * and then flattening the results.
+   * @since 1.6.0
+   */
   def flatMap[U : Encoder](func: T => TraversableOnce[U]): Dataset[U] =
     mapPartitions(_.flatMap(func))
 
+  /**
+   * (Java-specific)
+   * Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]],
+   * and then flattening the results.
+   * @since 1.6.0
+   */
+  def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
+    val func: (T) => Iterable[U] = x => f.call(x).asScala
+    flatMap(func)(encoder)
+  }
+
   /* ************** *
    *  Side effects  *
    * ************** */
 
   /**
+   * (Scala-specific)
    * Runs `func` on each element of this Dataset.
    * @since 1.6.0
    */
   def foreach(func: T => Unit): Unit = rdd.foreach(func)
 
   /**
+   * (Java-specific)
+   * Runs `func` on each element of this Dataset.
+   * @since 1.6.0
+   */
+  def foreach(func: VoidFunction[T]): Unit = foreach(func.call(_))
+
+  /**
+   * (Scala-specific)
    * Runs `func` on each partition of this Dataset.
    * @since 1.6.0
    */
   def foreachPartition(func: Iterator[T] => Unit): Unit = rdd.foreachPartition(func)
 
+  /**
+   * (Java-specific)
+   * Runs `func` on each partition of this Dataset.
+   * @since 1.6.0
+   */
+  def foreachPartition(func: VoidFunction[java.util.Iterator[T]]): Unit =
+    foreachPartition(it => func.call(it.asJava))
+
   /* ************* *
    *  Aggregation  *
    * ************* */
 
   /**
+   * (Scala-specific)
    * Reduces the elements of this Dataset using the specified  binary function.  The given function
    * must be commutative and associative or the result may be non-deterministic.
    * @since 1.6.0
@@ -208,6 +278,15 @@ class Dataset[T] private[sql](
   def reduce(func: (T, T) => T): T = rdd.reduce(func)
 
   /**
+   * (Java-specific)
+   * Reduces the elements of this Dataset using the specified  binary function.  The given function
+   * must be commutative and associative or the result may be non-deterministic.
+   * @since 1.6.0
+   */
+  def reduce(func: JFunction2[T, T, T]): T = reduce(func.call(_, _))
+
+  /**
+   * (Scala-specific)
    * Aggregates the elements of each partition, and then the results for all the partitions, using a
    * given associative and commutative function and a neutral "zero value".
    *
@@ -221,6 +300,15 @@ class Dataset[T] private[sql](
   def fold(zeroValue: T)(op: (T, T) => T): T = rdd.fold(zeroValue)(op)
 
   /**
+   * (Java-specific)
+   * Aggregates the elements of each partition, and then the results for all the partitions, using a
+   * given associative and commutative function and a neutral "zero value".
+   * @since 1.6.0
+   */
+  def fold(zeroValue: T, func: JFunction2[T, T, T]): T = fold(zeroValue)(func.call(_, _))
+
+  /**
+   * (Scala-specific)
    * Returns a [[GroupedDataset]] where the data is grouped by the given key function.
    * @since 1.6.0
    */
@@ -258,6 +346,14 @@ class Dataset[T] private[sql](
       keyAttributes)
   }
 
+  /**
+   * (Java-specific)
+   * Returns a [[GroupedDataset]] where the data is grouped by the given key function.
+   * @since 1.6.0
+   */
+  def groupBy[K](f: JFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] =
+    groupBy(f.call(_))(encoder)
+
   /* ****************** *
    *  Typed Relational  *
    * ****************** */
@@ -267,8 +363,7 @@ class Dataset[T] private[sql](
    * {{{
    *   df.select($"colA", $"colB" + 1)
    * }}}
-   * @group dfops
-   * @since 1.3.0
+   * @since 1.6.0
    */
   // Copied from Dataframe to make sure we don't have invalid overloads.
   @scala.annotation.varargs
@@ -279,7 +374,7 @@ class Dataset[T] private[sql](
    *
    * {{{
    *   val ds = Seq(1, 2, 3).toDS()
-   *   val newDS = ds.select(e[Int]("value + 1"))
+   *   val newDS = ds.select(expr("value + 1").as[Int])
    * }}}
    * @since 1.6.0
    */
@@ -405,6 +500,8 @@ class Dataset[T] private[sql](
    * This type of join can be useful both for preserving type-safety with the original object
    * types as well as working with relational data where either side of the join has column
    * names in common.
+   *
+   * @since 1.6.0
    */
   def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = {
     val left = this.logicalPlan
@@ -438,12 +535,31 @@ class Dataset[T] private[sql](
    *  Gather to Driver Actions  *
    * ************************** */
 
-  /** Returns the first element in this [[Dataset]]. */
+  /**
+   * Returns the first element in this [[Dataset]].
+   * @since 1.6.0
+   */
   def first(): T = rdd.first()
 
-  /** Collects the elements to an Array. */
+  /**
+   * Collects the elements to an Array.
+   * @since 1.6.0
+   */
   def collect(): Array[T] = rdd.collect()
 
+  /**
+   * (Java-specific)
+   * Collects the elements to a Java list.
+   *
+   * Because of an incompatibility between Scala and Java, the return type of [[collect()]] on the
+   * Java side is `java.lang.Object`, which is not easy to use. Java users can call this method
+   * instead to keep the generic result type.
+   *
+   * @since 1.6.0
+   */
+  def collectAsList(): java.util.List[T] =
+    rdd.collect().toSeq.asJava
+
   /** Returns the first `num` elements of this [[Dataset]] as an Array. */
   def take(num: Int): Array[T] = rdd.take(num)
 

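The Java-specific overloads added above mirror their Scala counterparts but take org.apache.spark.api.java.function interfaces and, where a new element type is produced, an explicit Encoder. A minimal sketch of reduce/fold from Java, following testReduce in the suite below (it reuses the `context`/`e` setup from the first example and needs org.apache.spark.api.java.function.Function2 on the import list):

    Dataset<Integer> nums = context.createDataset(Arrays.asList(1, 2, 3), e.INT());

    // reduce(JFunction2) folds the whole Dataset down to a single value.
    int sum = nums.reduce(new Function2<Integer, Integer, Integer>() {
      @Override
      public Integer call(Integer a, Integer b) throws Exception {
        return a + b;
      }
    });                                                   // 6

    // fold(zeroValue, JFunction2) does the same but starts from a neutral element.
    int product = nums.fold(1, new Function2<Integer, Integer, Integer>() {
      @Override
      public Integer call(Integer a, Integer b) throws Exception {
        return a * b;
      }
    });                                                   // 6
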
http://git-wip-us.apache.org/repos/asf/spark/blob/7e9a9e60/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala
index 45f0098..08097e9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala
@@ -27,9 +27,9 @@ package org.apache.spark.sql
  *
  * @since 1.6.0
  */
-case class DatasetHolder[T] private[sql](private val df: Dataset[T]) {
+case class DatasetHolder[T] private[sql](private val ds: Dataset[T]) {
 
   // This is declared with parentheses to prevent the Scala compiler from treating
-  // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame.
-  def toDS(): Dataset[T] = df
+  // `rdd.toDS("1")` as invoking this toDS and then apply on the returned Dataset.
+  def toDS(): Dataset[T] = ds
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/7e9a9e60/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index b8fc373..b2803d5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -17,7 +17,11 @@
 
 package org.apache.spark.sql
 
+import java.util.{Iterator => JIterator}
+import scala.collection.JavaConverters._
+
 import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.function.{Function2 => JFunction2, Function3 => JFunction3, _}
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute}
 import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, Encoder}
 import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, Alias, Attribute}
@@ -104,6 +108,12 @@ class GroupedDataset[K, T] private[sql](
       MapGroups(f, groupingAttributes, logicalPlan))
   }
 
+  def mapGroups[U](
+      f: JFunction2[K, JIterator[T], JIterator[U]],
+      encoder: Encoder[U]): Dataset[U] = {
+    mapGroups((key, data) => f.call(key, data.asJava).asScala)(encoder)
+  }
+
   // To ensure valid overloading.
   protected def agg(expr: Column, exprs: Column*): DataFrame =
     groupedData.agg(expr, exprs: _*)
@@ -196,4 +206,11 @@ class GroupedDataset[K, T] private[sql](
         this.logicalPlan,
         other.logicalPlan))
   }
+
+  def cogroup[U, R](
+      other: GroupedDataset[K, U],
+      f: JFunction3[K, JIterator[T], JIterator[U], JIterator[R]],
+      encoder: Encoder[R]): Dataset[R] = {
+    cogroup(other)((key, left, right) => f.call(key, left.asJava, right.asJava).asScala)(encoder)
+  }
 }
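
The mapGroups and cogroup overloads above accept java.util.Iterator on both sides and convert internally. A sketch of grouping and mapGroups from Java, modeled on testGroupBy in the suite below; it assumes a Dataset<String> named `ds` plus the `context`/`e` setup from the earlier examples, with java.util.Iterator, java.util.Collections, org.apache.spark.sql.GroupedDataset and the function interfaces imported:

    // Group strings by length, then concatenate each group behind its key.
    GroupedDataset<Integer, String> grouped = ds.groupBy(new Function<String, Integer>() {
      @Override
      public Integer call(String v) throws Exception {
        return v.length();
      }
    }, e.INT());

    Dataset<String> concatenated = grouped.mapGroups(
      new Function2<Integer, Iterator<String>, Iterator<String>>() {
        @Override
        public Iterator<String> call(Integer key, Iterator<String> values) throws Exception {
          StringBuilder sb = new StringBuilder(key.toString());
          while (values.hasNext()) {
            sb.append(values.next());
          }
          return Collections.singletonList(sb.toString()).iterator();
        }
      },
      e.STRING());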

http://git-wip-us.apache.org/repos/asf/spark/blob/7e9a9e60/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 5ad3871..5598731 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -508,6 +508,10 @@ class SQLContext private[sql](
     new Dataset[T](this, plan)
   }
 
+  def createDataset[T : Encoder](data: java.util.List[T]): Dataset[T] = {
+    createDataset(data.asScala)
+  }
+
   /**
    * Creates a DataFrame from an RDD[Row]. User can specify whether the input rows should be
    * converted to Catalyst rows.
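
Because createDataset uses an Encoder context bound, Java callers supply the encoder explicitly as a trailing argument, which is how JavaDatasetSuite invokes the new java.util.List overload (again assuming `context` and `e = Encoder$.MODULE$` from the earlier examples):

    Dataset<String> letters = context.createDataset(Arrays.asList("a", "b", "c"), e.STRING());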

http://git-wip-us.apache.org/repos/asf/spark/blob/7e9a9e60/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
new file mode 100644
index 0000000..a9493d5
--- /dev/null
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -0,0 +1,357 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql;
+
+import java.io.Serializable;
+import java.util.*;
+
+import scala.Tuple2;
+import scala.Tuple3;
+import scala.Tuple4;
+import scala.Tuple5;
+import org.junit.*;
+
+import org.apache.spark.Accumulator;
+import org.apache.spark.SparkContext;
+import org.apache.spark.api.java.function.*;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.catalyst.encoders.Encoder;
+import org.apache.spark.sql.catalyst.encoders.Encoder$;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.GroupedDataset;
+import org.apache.spark.sql.test.TestSQLContext;
+
+import static org.apache.spark.sql.functions.*;
+
+public class JavaDatasetSuite implements Serializable {
+  private transient JavaSparkContext jsc;
+  private transient TestSQLContext context;
+  private transient Encoder$ e = Encoder$.MODULE$;
+
+  @Before
+  public void setUp() {
+    // Trigger static initializer of TestData
+    SparkContext sc = new SparkContext("local[*]", "testing");
+    jsc = new JavaSparkContext(sc);
+    context = new TestSQLContext(sc);
+    context.loadTestData();
+  }
+
+  @After
+  public void tearDown() {
+    context.sparkContext().stop();
+    context = null;
+    jsc = null;
+  }
+
+  private <T1, T2> Tuple2<T1, T2> tuple2(T1 t1, T2 t2) {
+    return new Tuple2<T1, T2>(t1, t2);
+  }
+
+  @Test
+  public void testCollect() {
+    List<String> data = Arrays.asList("hello", "world");
+    Dataset<String> ds = context.createDataset(data, e.STRING());
+    String[] collected = (String[]) ds.collect();
+    Assert.assertEquals(Arrays.asList("hello", "world"), 
Arrays.asList(collected));
+  }
+
+  @Test
+  public void testCommonOperation() {
+    List<String> data = Arrays.asList("hello", "world");
+    Dataset<String> ds = context.createDataset(data, e.STRING());
+    Assert.assertEquals("hello", ds.first());
+
+    Dataset<String> filtered = ds.filter(new Function<String, Boolean>() {
+      @Override
+      public Boolean call(String v) throws Exception {
+        return v.startsWith("h");
+      }
+    });
+    Assert.assertEquals(Arrays.asList("hello"), filtered.collectAsList());
+
+
+    Dataset<Integer> mapped = ds.map(new Function<String, Integer>() {
+      @Override
+      public Integer call(String v) throws Exception {
+        return v.length();
+      }
+    }, e.INT());
+    Assert.assertEquals(Arrays.asList(5, 5), mapped.collectAsList());
+
+    Dataset<String> parMapped = ds.mapPartitions(new 
FlatMapFunction<Iterator<String>, String>() {
+      @Override
+      public Iterable<String> call(Iterator<String> it) throws Exception {
+        List<String> ls = new LinkedList<String>();
+        while (it.hasNext()) {
+          ls.add(it.next().toUpperCase());
+        }
+        return ls;
+      }
+    }, e.STRING());
+    Assert.assertEquals(Arrays.asList("HELLO", "WORLD"), 
parMapped.collectAsList());
+
+    Dataset<String> flatMapped = ds.flatMap(new FlatMapFunction<String, 
String>() {
+      @Override
+      public Iterable<String> call(String s) throws Exception {
+        List<String> ls = new LinkedList<String>();
+        for (char c : s.toCharArray()) {
+          ls.add(String.valueOf(c));
+        }
+        return ls;
+      }
+    }, e.STRING());
+    Assert.assertEquals(
+      Arrays.asList("h", "e", "l", "l", "o", "w", "o", "r", "l", "d"),
+      flatMapped.collectAsList());
+  }
+
+  @Test
+  public void testForeach() {
+    final Accumulator<Integer> accum = jsc.accumulator(0);
+    List<String> data = Arrays.asList("a", "b", "c");
+    Dataset<String> ds = context.createDataset(data, e.STRING());
+
+    ds.foreach(new VoidFunction<String>() {
+      @Override
+      public void call(String s) throws Exception {
+        accum.add(1);
+      }
+    });
+    Assert.assertEquals(3, accum.value().intValue());
+  }
+
+  @Test
+  public void testReduce() {
+    List<Integer> data = Arrays.asList(1, 2, 3);
+    Dataset<Integer> ds = context.createDataset(data, e.INT());
+
+    int reduced = ds.reduce(new Function2<Integer, Integer, Integer>() {
+      @Override
+      public Integer call(Integer v1, Integer v2) throws Exception {
+        return v1 + v2;
+      }
+    });
+    Assert.assertEquals(6, reduced);
+
+    int folded = ds.fold(1, new Function2<Integer, Integer, Integer>() {
+      @Override
+      public Integer call(Integer v1, Integer v2) throws Exception {
+        return v1 * v2;
+      }
+    });
+    Assert.assertEquals(6, folded);
+  }
+
+  @Test
+  public void testGroupBy() {
+    List<String> data = Arrays.asList("a", "foo", "bar");
+    Dataset<String> ds = context.createDataset(data, e.STRING());
+    GroupedDataset<Integer, String> grouped = ds.groupBy(new Function<String, 
Integer>() {
+      @Override
+      public Integer call(String v) throws Exception {
+        return v.length();
+      }
+    }, e.INT());
+
+    Dataset<String> mapped = grouped.mapGroups(
+      new Function2<Integer, Iterator<String>, Iterator<String>>() {
+        @Override
+        public Iterator<String> call(Integer key, Iterator<String> data) 
throws Exception {
+          StringBuilder sb = new StringBuilder(key.toString());
+          while (data.hasNext()) {
+            sb.append(data.next());
+          }
+          return Collections.singletonList(sb.toString()).iterator();
+        }
+      },
+      e.STRING());
+
+    Assert.assertEquals(Arrays.asList("1a", "3foobar"), 
mapped.collectAsList());
+
+    List<Integer> data2 = Arrays.asList(2, 6, 10);
+    Dataset<Integer> ds2 = context.createDataset(data2, e.INT());
+    GroupedDataset<Integer, Integer> grouped2 = ds2.groupBy(new 
Function<Integer, Integer>() {
+      @Override
+      public Integer call(Integer v) throws Exception {
+        return v / 2;
+      }
+    }, e.INT());
+
+    Dataset<String> cogrouped = grouped.cogroup(
+      grouped2,
+      new Function3<Integer, Iterator<String>, Iterator<Integer>, 
Iterator<String>>() {
+        @Override
+        public Iterator<String> call(
+            Integer key,
+            Iterator<String> left,
+            Iterator<Integer> right) throws Exception {
+          StringBuilder sb = new StringBuilder(key.toString());
+          while (left.hasNext()) {
+            sb.append(left.next());
+          }
+          sb.append("#");
+          while (right.hasNext()) {
+            sb.append(right.next());
+          }
+          return Collections.singletonList(sb.toString()).iterator();
+        }
+      },
+      e.STRING());
+
+    Assert.assertEquals(Arrays.asList("1a#2", "3foobar#6", "5#10"), 
cogrouped.collectAsList());
+  }
+
+  @Test
+  public void testGroupByColumn() {
+    List<String> data = Arrays.asList("a", "foo", "bar");
+    Dataset<String> ds = context.createDataset(data, e.STRING());
+    GroupedDataset<Integer, String> grouped = 
ds.groupBy(length(col("value"))).asKey(e.INT());
+
+    Dataset<String> mapped = grouped.mapGroups(
+      new Function2<Integer, Iterator<String>, Iterator<String>>() {
+        @Override
+        public Iterator<String> call(Integer key, Iterator<String> data) 
throws Exception {
+          StringBuilder sb = new StringBuilder(key.toString());
+          while (data.hasNext()) {
+            sb.append(data.next());
+          }
+          return Collections.singletonList(sb.toString()).iterator();
+        }
+      },
+      e.STRING());
+
+    Assert.assertEquals(Arrays.asList("1a", "3foobar"), 
mapped.collectAsList());
+  }
+
+  @Test
+  public void testSelect() {
+    List<Integer> data = Arrays.asList(2, 6);
+    Dataset<Integer> ds = context.createDataset(data, e.INT());
+
+    Dataset<Tuple2<Integer, String>> selected = ds.select(
+      expr("value + 1").as(e.INT()),
+      col("value").cast("string").as(e.STRING()));
+
+    Assert.assertEquals(
+      Arrays.asList(tuple2(3, "2"), tuple2(7, "6")),
+      selected.collectAsList());
+  }
+
+  @Test
+  public void testSetOperation() {
+    List<String> data = Arrays.asList("abc", "abc", "xyz");
+    Dataset<String> ds = context.createDataset(data, e.STRING());
+
+    Assert.assertEquals(
+      Arrays.asList("abc", "xyz"),
+      sort(ds.distinct().collectAsList().toArray(new String[0])));
+
+    List<String> data2 = Arrays.asList("xyz", "foo", "foo");
+    Dataset<String> ds2 = context.createDataset(data2, e.STRING());
+
+    Dataset<String> intersected = ds.intersect(ds2);
+    Assert.assertEquals(Arrays.asList("xyz"), intersected.collectAsList());
+
+    Dataset<String> unioned = ds.union(ds2);
+    Assert.assertEquals(
+      Arrays.asList("abc", "abc", "foo", "foo", "xyz", "xyz"),
+      sort(unioned.collectAsList().toArray(new String[0])));
+
+    Dataset<String> subtracted = ds.subtract(ds2);
+    Assert.assertEquals(Arrays.asList("abc", "abc"), 
subtracted.collectAsList());
+  }
+
+  private <T extends Comparable<T>> List<T> sort(T[] data) {
+    Arrays.sort(data);
+    return Arrays.asList(data);
+  }
+
+  @Test
+  public void testJoin() {
+    List<Integer> data = Arrays.asList(1, 2, 3);
+    Dataset<Integer> ds = context.createDataset(data, e.INT()).as("a");
+    List<Integer> data2 = Arrays.asList(2, 3, 4);
+    Dataset<Integer> ds2 = context.createDataset(data2, e.INT()).as("b");
+
+    Dataset<Tuple2<Integer, Integer>> joined =
+      ds.joinWith(ds2, col("a.value").equalTo(col("b.value")));
+    Assert.assertEquals(
+      Arrays.asList(tuple2(2, 2), tuple2(3, 3)),
+      joined.collectAsList());
+  }
+
+  @Test
+  public void testTupleEncoder() {
+    Encoder<Tuple2<Integer, String>> encoder2 = e.tuple(e.INT(), e.STRING());
+    List<Tuple2<Integer, String>> data2 = Arrays.asList(tuple2(1, "a"), 
tuple2(2, "b"));
+    Dataset<Tuple2<Integer, String>> ds2 = context.createDataset(data2, 
encoder2);
+    Assert.assertEquals(data2, ds2.collectAsList());
+
+    Encoder<Tuple3<Integer, Long, String>> encoder3 = e.tuple(e.INT(), 
e.LONG(), e.STRING());
+    List<Tuple3<Integer, Long, String>> data3 =
+      Arrays.asList(new Tuple3<Integer, Long, String>(1, 2L, "a"));
+    Dataset<Tuple3<Integer, Long, String>> ds3 = context.createDataset(data3, 
encoder3);
+    Assert.assertEquals(data3, ds3.collectAsList());
+
+    Encoder<Tuple4<Integer, String, Long, String>> encoder4 =
+      e.tuple(e.INT(), e.STRING(), e.LONG(), e.STRING());
+    List<Tuple4<Integer, String, Long, String>> data4 =
+      Arrays.asList(new Tuple4<Integer, String, Long, String>(1, "b", 2L, 
"a"));
+    Dataset<Tuple4<Integer, String, Long, String>> ds4 = 
context.createDataset(data4, encoder4);
+    Assert.assertEquals(data4, ds4.collectAsList());
+
+    Encoder<Tuple5<Integer, String, Long, String, Boolean>> encoder5 =
+      e.tuple(e.INT(), e.STRING(), e.LONG(), e.STRING(), e.BOOLEAN());
+    List<Tuple5<Integer, String, Long, String, Boolean>> data5 =
+      Arrays.asList(new Tuple5<Integer, String, Long, String, Boolean>(1, "b", 
2L, "a", true));
+    Dataset<Tuple5<Integer, String, Long, String, Boolean>> ds5 =
+      context.createDataset(data5, encoder5);
+    Assert.assertEquals(data5, ds5.collectAsList());
+  }
+
+  @Test
+  public void testNestedTupleEncoder() {
+    // test ((int, string), string)
+    Encoder<Tuple2<Tuple2<Integer, String>, String>> encoder =
+      e.tuple(e.tuple(e.INT(), e.STRING()), e.STRING());
+    List<Tuple2<Tuple2<Integer, String>, String>> data =
+      Arrays.asList(tuple2(tuple2(1, "a"), "a"), tuple2(tuple2(2, "b"), "b"));
+    Dataset<Tuple2<Tuple2<Integer, String>, String>> ds = 
context.createDataset(data, encoder);
+    Assert.assertEquals(data, ds.collectAsList());
+
+    // test (int, (string, string, long))
+    Encoder<Tuple2<Integer, Tuple3<String, String, Long>>> encoder2 =
+      e.tuple(e.INT(), e.tuple(e.STRING(), e.STRING(), e.LONG()));
+    List<Tuple2<Integer, Tuple3<String, String, Long>>> data2 =
+      Arrays.asList(tuple2(1, new Tuple3<String, String, Long>("a", "b", 3L)));
+    Dataset<Tuple2<Integer, Tuple3<String, String, Long>>> ds2 =
+      context.createDataset(data2, encoder2);
+    Assert.assertEquals(data2, ds2.collectAsList());
+
+    // test (int, ((string, long), string))
+    Encoder<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> encoder3 =
+      e.tuple(e.INT(), e.tuple(e.tuple(e.STRING(), e.LONG()), e.STRING()));
+    List<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> data3 =
+      Arrays.asList(tuple2(1, tuple2(tuple2("a", 2L), "b")));
+    Dataset<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> ds3 =
+      context.createDataset(data3, encoder3);
+    Assert.assertEquals(data3, ds3.collectAsList());
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/7e9a9e60/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index 3244355..e3b0346 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -59,7 +59,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
   test("foreach") {
     val ds = Seq(1, 2, 3).toDS()
     val acc = sparkContext.accumulator(0)
-    ds.foreach(acc +=)
+    ds.foreach(acc += _)
     assert(acc.value == 6)
   }
 

