Repository: spark
Updated Branches:
  refs/heads/master 2a4f88b6c -> 74d8d3d92


[SPARK-8450] [SQL] [PYSPARK] Clean up type converter for Python DataFrame

This PR fixes the type converter for Python DataFrames, especially for DecimalType.

Closes #7106

Author: Davies Liu <dav...@databricks.com>

Closes #7131 from davies/decimal_python and squashes the following commits:

4d3c234 [Davies Liu] Merge branch 'master' of github.com:apache/spark into 
decimal_python
20531d6 [Davies Liu] Merge branch 'master' of github.com:apache/spark into 
decimal_python
7d73168 [Davies Liu] fix conflict
6cdd86a [Davies Liu] Merge branch 'master' of github.com:apache/spark into 
decimal_python
7104e97 [Davies Liu] improve type infer
9cd5a21 [Davies Liu] run python tests with SPARK_PREPEND_CLASSES
829a05b [Davies Liu] fix UDT in python
c99e8c5 [Davies Liu] fix mima
c46814a [Davies Liu] convert decimal for Python DataFrames


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/74d8d3d9
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/74d8d3d9
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/74d8d3d9

Branch: refs/heads/master
Commit: 74d8d3d928cc9a7386b68588ac89ae042847d146
Parents: 2a4f88b
Author: Davies Liu <dav...@databricks.com>
Authored: Wed Jul 8 18:22:53 2015 -0700
Committer: Davies Liu <davies....@gmail.com>
Committed: Wed Jul 8 18:22:53 2015 -0700

----------------------------------------------------------------------
 .../apache/spark/mllib/linalg/Matrices.scala    | 10 +--
 .../org/apache/spark/mllib/linalg/Vectors.scala | 16 +---
 project/MimaExcludes.scala                      |  5 +-
 python/pyspark/sql/tests.py                     | 13 +++
 python/pyspark/sql/types.py                     |  4 +
 python/run-tests.py                             |  3 +-
 .../scala/org/apache/spark/sql/DataFrame.scala  |  4 +-
 .../scala/org/apache/spark/sql/SQLContext.scala | 28 +-----
 .../apache/spark/sql/execution/pythonUDFs.scala | 95 +++++++++++---------
 9 files changed, 84 insertions(+), 94 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/74d8d3d9/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index 75e7004..0df0766 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -24,9 +24,9 @@ import scala.collection.mutable.{ArrayBuilder => 
MArrayBuilder, HashSet => MHash
 import breeze.linalg.{CSCMatrix => BSM, DenseMatrix => BDM, Matrix => BM}
 
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.types._
 import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types._
 
 /**
  * Trait for a local matrix.
@@ -147,7 +147,7 @@ private[spark] class MatrixUDT extends 
UserDefinedType[Matrix] {
       ))
   }
 
-  override def serialize(obj: Any): Row = {
+  override def serialize(obj: Any): InternalRow = {
     val row = new GenericMutableRow(7)
     obj match {
       case sm: SparseMatrix =>
@@ -173,9 +173,7 @@ private[spark] class MatrixUDT extends 
UserDefinedType[Matrix] {
 
   override def deserialize(datum: Any): Matrix = {
     datum match {
-      // TODO: something wrong with UDT serialization, should never happen.
-      case m: Matrix => m
-      case row: Row =>
+      case row: InternalRow =>
         require(row.length == 7,
           s"MatrixUDT.deserialize given row with length ${row.length} but 
requires length == 7")
         val tpe = row.getByte(0)

http://git-wip-us.apache.org/repos/asf/spark/blob/74d8d3d9/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index c9c2742..e048b01 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -28,7 +28,7 @@ import breeze.linalg.{DenseVector => BDV, SparseVector => 
BSV, Vector => BV}
 import org.apache.spark.SparkException
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.mllib.util.NumericParser
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
 import org.apache.spark.sql.types._
 
@@ -175,7 +175,7 @@ private[spark] class VectorUDT extends 
UserDefinedType[Vector] {
       StructField("values", ArrayType(DoubleType, containsNull = false), 
nullable = true)))
   }
 
-  override def serialize(obj: Any): Row = {
+  override def serialize(obj: Any): InternalRow = {
     obj match {
       case SparseVector(size, indices, values) =>
         val row = new GenericMutableRow(4)
@@ -191,17 +191,12 @@ private[spark] class VectorUDT extends 
UserDefinedType[Vector] {
         row.setNullAt(2)
         row.update(3, values.toSeq)
         row
-      // TODO: There are bugs in UDT serialization because we don't have a 
clear separation between
-      // TODO: internal SQL types and language specific types (including UDT). 
UDT serialize and
-      // TODO: deserialize may get called twice. See SPARK-7186.
-      case row: Row =>
-        row
     }
   }
 
   override def deserialize(datum: Any): Vector = {
     datum match {
-      case row: Row =>
+      case row: InternalRow =>
         require(row.length == 4,
           s"VectorUDT.deserialize given row with length ${row.length} but 
requires length == 4")
         val tpe = row.getByte(0)
@@ -215,11 +210,6 @@ private[spark] class VectorUDT extends 
UserDefinedType[Vector] {
             val values = row.getAs[Iterable[Double]](3).toArray
             new DenseVector(values)
         }
-      // TODO: There are bugs in UDT serialization because we don't have a 
clear separation between
-      // TODO: internal SQL types and language specific types (including UDT). 
UDT serialize and
-      // TODO: deserialize may get called twice. See SPARK-7186.
-      case v: Vector =>
-        v
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/74d8d3d9/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 57a86bf..821aadd 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -63,7 +63,10 @@ object MimaExcludes {
             // SQL execution is considered private.
             excludePackage("org.apache.spark.sql.execution"),
             // Parquet support is considered private.
-            excludePackage("org.apache.spark.sql.parquet")
+            excludePackage("org.apache.spark.sql.parquet"),
+            // local function inside a method
+            ProblemFilters.exclude[MissingMethodProblem](
+              
"org.apache.spark.sql.SQLContext.org$apache$spark$sql$SQLContext$$needsConversion$1")
           ) ++ Seq(
             // SPARK-8479 Add numNonzeros and numActives to Matrix.
             ProblemFilters.exclude[MissingMethodProblem](

http://git-wip-us.apache.org/repos/asf/spark/blob/74d8d3d9/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 333378c..66827d4 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -700,6 +700,19 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertTrue(now - now1 < datetime.timedelta(0.001))
         self.assertTrue(now - utcnow1 < datetime.timedelta(0.001))
 
+    def test_decimal(self):
+        from decimal import Decimal
+        schema = StructType([StructField("decimal", DecimalType(10, 5))])
+        df = self.sqlCtx.createDataFrame([(Decimal("3.14159"),)], schema)
+        row = df.select(df.decimal + 1).first()
+        self.assertEqual(row[0], Decimal("4.14159"))
+        tmpPath = tempfile.mkdtemp()
+        shutil.rmtree(tmpPath)
+        df.write.parquet(tmpPath)
+        df2 = self.sqlCtx.read.parquet(tmpPath)
+        row = df2.first()
+        self.assertEqual(row[0], Decimal("3.14159"))
+
     def test_dropna(self):
         schema = StructType([
             StructField("name", StringType(), True),

http://git-wip-us.apache.org/repos/asf/spark/blob/74d8d3d9/python/pyspark/sql/types.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 160df40..7e64cb0 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1069,6 +1069,10 @@ def _verify_type(obj, dataType):
     if obj is None:
         return
 
+    # StringType can work with any types
+    if isinstance(dataType, StringType):
+        return
+
     if isinstance(dataType, UserDefinedType):
         if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType):
             raise ValueError("%r is not an instance of type %r" % (obj, 
dataType))

http://git-wip-us.apache.org/repos/asf/spark/blob/74d8d3d9/python/run-tests.py
----------------------------------------------------------------------
diff --git a/python/run-tests.py b/python/run-tests.py
index 7638854..cc56077 100755
--- a/python/run-tests.py
+++ b/python/run-tests.py
@@ -72,7 +72,8 @@ LOGGER = logging.getLogger()
 
 
 def run_individual_python_test(test_name, pyspark_python):
-    env = {'SPARK_TESTING': '1', 'PYSPARK_PYTHON': which(pyspark_python)}
+    env = dict(os.environ)
+    env.update({'SPARK_TESTING': '1', 'PYSPARK_PYTHON': which(pyspark_python)})
     LOGGER.debug("Starting test(%s): %s", pyspark_python, test_name)
     start_time = time.time()
     try:

http://git-wip-us.apache.org/repos/asf/spark/blob/74d8d3d9/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index eeefc85..d9f987a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -1549,8 +1549,8 @@ class DataFrame private[sql](
    * Converts a JavaRDD to a PythonRDD.
    */
   protected[sql] def javaToPython: JavaRDD[Array[Byte]] = {
-    val fieldTypes = schema.fields.map(_.dataType)
-    val jrdd = rdd.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD()
+    val structType = schema  // capture it for closure
+    val jrdd = queryExecution.toRdd.map(EvaluatePython.toJava(_, 
structType)).toJavaRDD()
     SerDeUtil.javaToPython(jrdd)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/74d8d3d9/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 079f31a..477dea9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -1044,33 +1044,7 @@ class SQLContext(@transient val sparkContext: 
SparkContext)
       rdd: RDD[Array[Any]],
       schema: StructType): DataFrame = {
 
-    def needsConversion(dataType: DataType): Boolean = dataType match {
-      case ByteType => true
-      case ShortType => true
-      case LongType => true
-      case FloatType => true
-      case DateType => true
-      case TimestampType => true
-      case StringType => true
-      case ArrayType(_, _) => true
-      case MapType(_, _, _) => true
-      case StructType(_) => true
-      case udt: UserDefinedType[_] => needsConversion(udt.sqlType)
-      case other => false
-    }
-
-    val convertedRdd = if (schema.fields.exists(f => 
needsConversion(f.dataType))) {
-      rdd.map(m => m.zip(schema.fields).map {
-        case (value, field) => EvaluatePython.fromJava(value, field.dataType)
-      })
-    } else {
-      rdd
-    }
-
-    val rowRdd = convertedRdd.mapPartitions { iter =>
-      iter.map { m => new GenericInternalRow(m): InternalRow}
-    }
-
+    val rowRdd = rdd.map(r => EvaluatePython.fromJava(r, 
schema).asInstanceOf[InternalRow])
     DataFrame(this, LogicalRDD(schema.toAttributes, rowRdd)(self))
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/74d8d3d9/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
index 6946e79..1c8130b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
@@ -24,20 +24,19 @@ import scala.collection.JavaConverters._
 
 import net.razorvine.pickle.{Pickler, Unpickler}
 
-import org.apache.spark.{Accumulator, Logging => SparkLogging}
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.api.python.{PythonBroadcast, PythonRDD}
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.{Accumulator, Logging => SparkLogging}
 
 /**
  * A serialized version of a Python lambda function.  Suitable for use in a 
[[PythonRDD]].
@@ -125,59 +124,86 @@ object EvaluatePython {
     new EvaluatePython(udf, child, AttributeReference("pythonUDF", 
udf.dataType)())
 
   /**
-   * Helper for converting a Scala object to a java suitable for pyspark 
serialization.
+   * Helper for converting from Catalyst type to java type suitable for 
Pyrolite.
    */
   def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
     case (null, _) => null
 
-    case (row: Row, struct: StructType) =>
+    case (row: InternalRow, struct: StructType) =>
       val fields = struct.fields.map(field => field.dataType)
-      row.toSeq.zip(fields).map {
-        case (obj, dataType) => toJava(obj, dataType)
-      }.toArray
+      rowToArray(row, fields)
 
     case (seq: Seq[Any], array: ArrayType) =>
       seq.map(x => toJava(x, array.elementType)).asJava
-    case (list: JList[_], array: ArrayType) =>
-      list.map(x => toJava(x, array.elementType)).asJava
-    case (arr, array: ArrayType) if arr.getClass.isArray =>
-      arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType))
 
     case (obj: Map[_, _], mt: MapType) => obj.map {
       case (k, v) => (toJava(k, mt.keyType), toJava(v, mt.valueType))
     }.asJava
 
-    case (ud, udt: UserDefinedType[_]) => toJava(udt.serialize(ud), 
udt.sqlType)
+    case (ud, udt: UserDefinedType[_]) => toJava(ud, udt.sqlType)
 
     case (date: Int, DateType) => DateTimeUtils.toJavaDate(date)
     case (t: Long, TimestampType) => DateTimeUtils.toJavaTimestamp(t)
+
+    case (d: Decimal, _) => d.toJavaBigDecimal
+
     case (s: UTF8String, StringType) => s.toString
 
-    // Pyrolite can handle Timestamp and Decimal
     case (other, _) => other
   }
 
   /**
    * Convert Row into Java Array (for pickled into Python)
    */
-  def rowToArray(row: Row, fields: Seq[DataType]): Array[Any] = {
+  def rowToArray(row: InternalRow, fields: Seq[DataType]): Array[Any] = {
     // TODO: this is slow!
     row.toSeq.zip(fields).map {case (obj, dt) => toJava(obj, dt)}.toArray
   }
 
-  // Converts value to the type specified by the data type.
-  // Because Python does not have data types for TimestampType, FloatType, 
ShortType, and
-  // ByteType, we need to explicitly convert values in columns of these data 
types to the desired
-  // JVM data types.
+  /**
+   * Converts `obj` to the type specified by the data type, or returns null if 
the type of obj is
+   * unexpected. Because Python doesn't enforce the type.
+   */
   def fromJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
-    // TODO: We should check nullable
     case (null, _) => null
 
+    case (c: Boolean, BooleanType) => c
+
+    case (c: Int, ByteType) => c.toByte
+    case (c: Long, ByteType) => c.toByte
+
+    case (c: Int, ShortType) => c.toShort
+    case (c: Long, ShortType) => c.toShort
+
+    case (c: Int, IntegerType) => c
+    case (c: Long, IntegerType) => c.toInt
+
+    case (c: Int, LongType) => c.toLong
+    case (c: Long, LongType) => c
+
+    case (c: Double, FloatType) => c.toFloat
+
+    case (c: Double, DoubleType) => c
+
+    case (c: java.math.BigDecimal, dt: DecimalType) => Decimal(c)
+
+    case (c: Int, DateType) => c
+
+    case (c: Long, TimestampType) => c
+
+    case (c: String, StringType) => UTF8String.fromString(c)
+    case (c, StringType) =>
+      // If we get here, c is not a string. Call toString on it.
+      UTF8String.fromString(c.toString)
+
+    case (c: String, BinaryType) => c.getBytes("utf-8")
+    case (c, BinaryType) if c.getClass.isArray && 
c.getClass.getComponentType.getName == "byte" => c
+
     case (c: java.util.List[_], ArrayType(elementType, _)) =>
-      c.map { e => fromJava(e, elementType)}: Seq[Any]
+      c.map { e => fromJava(e, elementType)}.toSeq
 
     case (c, ArrayType(elementType, _)) if c.getClass.isArray =>
-      c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType)): Seq[Any]
+      c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType)).toSeq
 
     case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => c.map {
       case (key, value) => (fromJava(key, keyType), fromJava(value, valueType))
@@ -188,30 +214,11 @@ object EvaluatePython {
         case (e, f) => fromJava(e, f.dataType)
       })
 
-    case (c: java.util.Calendar, DateType) =>
-      DateTimeUtils.fromJavaDate(new java.sql.Date(c.getTimeInMillis))
-
-    case (c: java.util.Calendar, TimestampType) =>
-      c.getTimeInMillis * 10000L
-    case (t: java.sql.Timestamp, TimestampType) =>
-      DateTimeUtils.fromJavaTimestamp(t)
-
-    case (_, udt: UserDefinedType[_]) =>
-      fromJava(obj, udt.sqlType)
-
-    case (c: Int, ByteType) => c.toByte
-    case (c: Long, ByteType) => c.toByte
-    case (c: Int, ShortType) => c.toShort
-    case (c: Long, ShortType) => c.toShort
-    case (c: Long, IntegerType) => c.toInt
-    case (c: Int, LongType) => c.toLong
-    case (c: Double, FloatType) => c.toFloat
-    case (c: String, StringType) => UTF8String.fromString(c)
-    case (c, StringType) =>
-      // If we get here, c is not a string. Call toString on it.
-      UTF8String.fromString(c.toString)
+    case (_, udt: UserDefinedType[_]) => fromJava(obj, udt.sqlType)
 
-    case (c, _) => c
+    // all other unexpected type should be null, or we will have runtime 
exception
+    // TODO(davies): we could improve this by try to cast the object to 
expected type
+    case (c, _) => null
   }
 }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to