Repository: spark
Updated Branches:
  refs/heads/branch-1.5 28bb97730 -> 864d5de6d
[SPARK-9119] [SPARK-8359] [SQL] match Decimal.precision/scale with DecimalType

Let Decimal carry the correct precision and scale with DecimalType.

cc rxin yhuai

Author: Davies Liu <dav...@databricks.com>

Closes #7925 from davies/decimal_scale and squashes the following commits:

e19701a [Davies Liu] some tweaks
57d78d2 [Davies Liu] fix tests
5d5bc69 [Davies Liu] match precision and scale with DecimalType

(cherry picked from commit 781c8d71a0a6a86c84048a4f22cb3a7d035a5be2)
Signed-off-by: Davies Liu <davies....@gmail.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/864d5de6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/864d5de6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/864d5de6

Branch: refs/heads/branch-1.5
Commit: 864d5de6da3110974ddf0fbd216e3ef934a9f034
Parents: 28bb977
Author: Davies Liu <dav...@databricks.com>
Authored: Tue Aug 4 23:12:49 2015 -0700
Committer: Davies Liu <davies....@gmail.com>
Committed: Tue Aug 4 23:13:03 2015 -0700

----------------------------------------------------------------------
 .../main/scala/org/apache/spark/sql/Row.scala   |  4 +++
 .../sql/catalyst/CatalystTypeConverters.scala   | 21 ++++++-----
 .../catalyst/analysis/HiveTypeCoercion.scala    |  4 +--
 .../spark/sql/catalyst/expressions/Cast.scala   |  6 ++--
 .../sql/catalyst/expressions/arithmetic.scala   |  2 +-
 .../org/apache/spark/sql/types/Decimal.scala    | 37 ++++++++++++++++----
 .../spark/sql/types/decimal/DecimalSuite.scala  | 21 +++++++++++
 .../sql/execution/SparkSqlSerializer2.scala     |  3 +-
 .../apache/spark/sql/execution/pythonUDFs.scala |  3 +-
 .../org/apache/spark/sql/json/InferSchema.scala | 21 ++++++++---
 .../apache/spark/sql/json/JacksonParser.scala   |  5 ++-
 .../apache/spark/sql/JavaApplySchemaSuite.java  |  2 +-
 .../columnar/InMemoryColumnarQuerySuite.scala   |  2 +-
 .../org/apache/spark/sql/json/JsonSuite.scala   | 26 ++++++--------
 .../spark/sql/parquet/ParquetQuerySuite.scala   | 13 +++++++
 .../hive/execution/ScriptTransformation.scala   |  2 +-
 16 files changed, 122 insertions(+), 50 deletions(-)
----------------------------------------------------------------------
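
The heart of the change: a Decimal is now created with, and constrained to, the
precision and scale of its DecimalType, and conversions that cannot fit a value
return null instead of silently widening to the system default (38, 18). A
minimal sketch of that contract, assuming the Spark 1.5 Decimal API as patched
below (including the new java.math.BigDecimal apply overload):

    import org.apache.spark.sql.types.Decimal

    // Creating a value against (precision = 10, scale = 2) rounds to that scale.
    val d = Decimal(new java.math.BigDecimal("12.345"), 10, 2)
    assert(d.toJavaBigDecimal.toPlainString == "12.35")   // HALF_UP rounding

    // changePrecision reports whether the value still fits; the converter
    // changes below turn a false into a null result.
    assert(!Decimal(new java.math.BigDecimal("123456789")).changePrecision(5, 2))
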
http://git-wip-us.apache.org/repos/asf/spark/blob/864d5de6/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
index 9144947..40159aa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -417,6 +417,10 @@ trait Row extends Serializable {
           if (!o2.isInstanceOf[Double] || !
             java.lang.Double.isNaN(o2.asInstanceOf[Double])) {
             return false
           }
+        case d1: java.math.BigDecimal if o2.isInstanceOf[java.math.BigDecimal] =>
+          if (d1.compareTo(o2.asInstanceOf[java.math.BigDecimal]) != 0) {
+            return false
+          }
         case _ => if (o1 != o2) {
           return false
         }
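
The new BigDecimal case is needed because java.math.BigDecimal.equals is
scale-sensitive while compareTo compares numeric value only; with Decimals now
carrying their real scale, two rows holding the same quantity at different
scales must still compare equal. A plain-Java illustration (not Spark code):

    val a = new java.math.BigDecimal("2.0")
    val b = new java.math.BigDecimal("2.00")
    assert(!a.equals(b))          // equals also compares scale: 2.0 != 2.00
    assert(a.compareTo(b) == 0)   // compareTo compares only the numeric value
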
http://git-wip-us.apache.org/repos/asf/spark/blob/864d5de6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index c666864..8d0c64e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -317,18 +317,23 @@ object CatalystTypeConverters {

   private class DecimalConverter(dataType: DecimalType)
     extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] {
-    override def toCatalystImpl(scalaValue: Any): Decimal = scalaValue match {
-      case d: BigDecimal => Decimal(d)
-      case d: JavaBigDecimal => Decimal(d)
-      case d: Decimal => d
+    override def toCatalystImpl(scalaValue: Any): Decimal = {
+      val decimal = scalaValue match {
+        case d: BigDecimal => Decimal(d)
+        case d: JavaBigDecimal => Decimal(d)
+        case d: Decimal => d
+      }
+      if (decimal.changePrecision(dataType.precision, dataType.scale)) {
+        decimal
+      } else {
+        null
+      }
     }
     override def toScala(catalystValue: Decimal): JavaBigDecimal = catalystValue.toJavaBigDecimal
     override def toScalaImpl(row: InternalRow, column: Int): JavaBigDecimal =
       row.getDecimal(column, dataType.precision, dataType.scale).toJavaBigDecimal
   }

-  private object BigDecimalConverter extends DecimalConverter(DecimalType.SYSTEM_DEFAULT)
-
   private abstract class PrimitiveConverter[T] extends CatalystTypeConverter[T, Any, Any] {
     final override def toScala(catalystValue: Any): Any = catalystValue
     final override def toCatalystImpl(scalaValue: T): Any = scalaValue
@@ -413,8 +418,8 @@ object CatalystTypeConverters {
     case s: String => StringConverter.toCatalyst(s)
     case d: Date => DateConverter.toCatalyst(d)
     case t: Timestamp => TimestampConverter.toCatalyst(t)
-    case d: BigDecimal => BigDecimalConverter.toCatalyst(d)
-    case d: JavaBigDecimal => BigDecimalConverter.toCatalyst(d)
+    case d: BigDecimal => new DecimalConverter(DecimalType(d.precision, d.scale)).toCatalyst(d)
+    case d: JavaBigDecimal => new DecimalConverter(DecimalType(d.precision, d.scale)).toCatalyst(d)
     case seq: Seq[Any] => new GenericArrayData(seq.map(convertToCatalyst).toArray)
     case r: Row => InternalRow(r.toSeq.map(convertToCatalyst): _*)
     case arr: Array[Any] => new GenericArrayData(arr.map(convertToCatalyst))
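
The converter now enforces the target type per value instead of funnelling
everything through a shared SYSTEM_DEFAULT converter. A hedged sketch of the
resulting contract (toCatalystDecimal is a hypothetical standalone name, not a
Spark API):

    import org.apache.spark.sql.types.{Decimal, DecimalType}

    def toCatalystDecimal(value: java.math.BigDecimal, dt: DecimalType): Decimal = {
      val d = Decimal(value)
      // changePrecision rescales in place and reports success.
      if (d.changePrecision(dt.precision, dt.scale)) d else null
    }

    // 123.45 fits DecimalType(5, 2); 123456.7 overflows its integral range.
    assert(toCatalystDecimal(new java.math.BigDecimal("123.45"), DecimalType(5, 2)) != null)
    assert(toCatalystDecimal(new java.math.BigDecimal("123456.7"), DecimalType(5, 2)) == null)
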
http://git-wip-us.apache.org/repos/asf/spark/blob/864d5de6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 422d423..490f3dc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -442,8 +442,8 @@ object HiveTypeCoercion {
    * Changes numeric values to booleans so that expressions like true = 1 can be evaluated.
    */
   object BooleanEquality extends Rule[LogicalPlan] {
-    private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1))
-    private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, Decimal(0))
+    private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, Decimal.ONE)
+    private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, Decimal.ZERO)

     private def buildCaseKeyWhen(booleanExpr: Expression, numericExpr: Expression) = {
       CaseKeyWhen(numericExpr, Seq(


http://git-wip-us.apache.org/repos/asf/spark/blob/864d5de6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 88429bb..39f9970 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -26,8 +26,6 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

-import scala.collection.mutable
-
 object Cast {

@@ -157,7 +155,7 @@ case class Cast(child: Expression, dataType: DataType)
     case ByteType =>
       buildCast[Byte](_, _ != 0)
     case DecimalType() =>
-      buildCast[Decimal](_, _ != Decimal(0))
+      buildCast[Decimal](_, _ != Decimal.ZERO)
     case DoubleType =>
       buildCast[Double](_, _ != 0)
     case FloatType =>
@@ -311,7 +309,7 @@ case class Cast(child: Expression, dataType: DataType)
         case _: NumberFormatException => null
       })
     case BooleanType =>
       buildCast[Boolean](_, b => changePrecision(if (b) Decimal.ONE else Decimal.ZERO, target))
     case DateType =>
       buildCast[Int](_, d => null) // date can't cast to decimal in Hive
     case TimestampType =>


http://git-wip-us.apache.org/repos/asf/spark/blob/864d5de6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 0891b55..5808e3f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -511,6 +511,6 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {

   private def pmod(a: Decimal, n: Decimal): Decimal = {
     val r = a % n
-    if (r.compare(Decimal(0)) < 0) {(r + n) % n} else r
+    if (r.compare(Decimal.ZERO) < 0) {(r + n) % n} else r
   }
 }
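
These three hunks are mechanical: ad-hoc Decimal(0)/Decimal(1) allocations
become the shared Decimal.ZERO/Decimal.ONE constants introduced below in
Decimal.scala. The Pmod behavior itself is unchanged; since ZERO is
private[sql], a standalone sketch of the same contract uses Decimal(0):

    import org.apache.spark.sql.types.Decimal

    def pmod(a: Decimal, n: Decimal): Decimal = {
      val r = a % n                                      // % follows the dividend's sign
      if (r.compare(Decimal(0)) < 0) (r + n) % n else r  // shift into [0, n)
    }

    assert(pmod(Decimal(-7), Decimal(3)) == Decimal(2))  // plain % would give -1
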
http://git-wip-us.apache.org/repos/asf/spark/blob/864d5de6/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index c0155ee..624c3f3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -17,6 +17,8 @@

 package org.apache.spark.sql.types

+import java.math.{RoundingMode, MathContext}
+
 import org.apache.spark.annotation.DeveloperApi

 /**
@@ -28,7 +30,7 @@ import org.apache.spark.annotation.DeveloperApi
  *   - Otherwise, the decimal value is longVal / (10 ** _scale)
  */
 final class Decimal extends Ordered[Decimal] with Serializable {
-  import org.apache.spark.sql.types.Decimal.{BIG_DEC_ZERO, MAX_LONG_DIGITS, POW_10, ROUNDING_MODE}
+  import org.apache.spark.sql.types.Decimal._

   private var decimalVal: BigDecimal = null
   private var longVal: Long = 0L
@@ -137,9 +139,9 @@ final class Decimal extends Ordered[Decimal] with Serializable {

   def toBigDecimal: BigDecimal = {
     if (decimalVal.ne(null)) {
-      decimalVal
+      decimalVal(MATH_CONTEXT)
     } else {
-      BigDecimal(longVal, _scale)
+      BigDecimal(longVal, _scale)(MATH_CONTEXT)
     }
   }
@@ -261,10 +263,23 @@ final class Decimal extends Ordered[Decimal] with Serializable {

   def isZero: Boolean = if (decimalVal.ne(null)) decimalVal == BIG_DEC_ZERO else longVal == 0

-  def + (that: Decimal): Decimal = Decimal(toBigDecimal + that.toBigDecimal)
+  def + (that: Decimal): Decimal = {
+    if (decimalVal.eq(null) && that.decimalVal.eq(null) && scale == that.scale) {
+      Decimal(longVal + that.longVal, Math.max(precision, that.precision), scale)
+    } else {
+      Decimal(toBigDecimal + that.toBigDecimal, precision, scale)
+    }
+  }

-  def - (that: Decimal): Decimal = Decimal(toBigDecimal - that.toBigDecimal)
+  def - (that: Decimal): Decimal = {
+    if (decimalVal.eq(null) && that.decimalVal.eq(null) && scale == that.scale) {
+      Decimal(longVal - that.longVal, Math.max(precision, that.precision), scale)
+    } else {
+      Decimal(toBigDecimal - that.toBigDecimal, precision, scale)
+    }
+  }

+  // HiveTypeCoercion will take care of the precision, scale of result
   def * (that: Decimal): Decimal = Decimal(toBigDecimal * that.toBigDecimal)

   def / (that: Decimal): Decimal =
@@ -277,13 +292,13 @@ final class Decimal extends Ordered[Decimal] with Serializable {

   def unary_- : Decimal = {
     if (decimalVal.ne(null)) {
-      Decimal(-decimalVal)
+      Decimal(-decimalVal, precision, scale)
     } else {
       Decimal(-longVal, precision, scale)
     }
   }

-  def abs: Decimal = if (this.compare(Decimal(0)) < 0) this.unary_- else this
+  def abs: Decimal = if (this.compare(Decimal.ZERO) < 0) this.unary_- else this
 }

 object Decimal {
@@ -296,6 +311,11 @@ object Decimal {

   private val BIG_DEC_ZERO = BigDecimal(0)

+  private val MATH_CONTEXT = new MathContext(DecimalType.MAX_PRECISION, RoundingMode.HALF_UP)
+
+  private[sql] val ZERO = Decimal(0)
+  private[sql] val ONE = Decimal(1)
+
   def apply(value: Double): Decimal = new Decimal().set(value)

   def apply(value: Long): Decimal = new Decimal().set(value)
@@ -309,6 +329,9 @@ object Decimal {
   def apply(value: BigDecimal, precision: Int, scale: Int): Decimal =
     new Decimal().set(value, precision, scale)

+  def apply(value: java.math.BigDecimal, precision: Int, scale: Int): Decimal =
+    new Decimal().set(value, precision, scale)
+
   def apply(unscaled: Long, precision: Int, scale: Int): Decimal =
     new Decimal().set(unscaled, precision, scale)
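
Two behavioral points in this file. First, + and - now take a long-only fast
path when both operands share a scale and fit in a Long, avoiding a round-trip
through BigDecimal. Second, toBigDecimal attaches MathContext(MAX_PRECISION,
HALF_UP), so dividing decimals with a non-terminating quotient rounds to 38
significant digits instead of throwing. A sketch, assuming the 1.5 Decimal API:

    import org.apache.spark.sql.types.Decimal

    // Fast path: same scale, both long-backed; 1.23 + 8.77 = 10.00.
    assert((Decimal(123L, 10, 2) + Decimal(877L, 10, 2)).toUnscaledLong == 1000L)

    // 1.000 / 3.000 no longer throws ArithmeticException; it rounds.
    val q = Decimal(1.0, 10, 3) / Decimal(3.0, 10, 3)
    assert(math.abs(q.toDouble - 1.0 / 3.0) < 1e-9)
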
http://git-wip-us.apache.org/repos/asf/spark/blob/864d5de6/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala
index 1d297be..6921d15 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala
@@ -166,6 +166,27 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
     assert(Decimal(100) % Decimal(0) === null)
   }

+  // regression test for SPARK-8359
+  test("accurate precision after multiplication") {
+    val decimal = (Decimal(Long.MaxValue, 38, 0) * Decimal(Long.MaxValue, 38, 0)).toJavaBigDecimal
+    assert(decimal.unscaledValue.toString === "85070591730234615847396907784232501249")
+  }
+
+  // regression test for SPARK-8677
+  test("fix non-terminating decimal expansion problem") {
+    val decimal = Decimal(1.0, 10, 3) / Decimal(3.0, 10, 3)
+    // The difference between decimal should not be more than 0.001.
+    assert(decimal.toDouble - 0.333 < 0.001)
+  }
+
+  // regression test for SPARK-8800
+  test("fix loss of precision/scale when doing division operation") {
+    val a = Decimal(2) / Decimal(3)
+    assert(a.toDouble < 1.0 && a.toDouble > 0.6)
+    val b = Decimal(1) / Decimal(8)
+    assert(b.toDouble === 0.125)
+  }
+
   test("set/setOrNull") {
     assert(new Decimal().set(10L, 10, 0).toUnscaledLong === 10L)
     assert(new Decimal().set(100L, 10, 0).toUnscaledLong === 100L)
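
For reference, the failure mode the SPARK-8677 test guards against comes
straight from java.math.BigDecimal: an exact division with an infinite decimal
expansion throws unless a MathContext bounds the result (plain-Java
illustration, not Spark code):

    import java.math.{BigDecimal => JBigDecimal, MathContext}

    // Throws "Non-terminating decimal expansion; no exact representable decimal result."
    try { new JBigDecimal("1.000").divide(new JBigDecimal("3.000")) }
    catch { case e: ArithmeticException => println(e.getMessage) }

    // Bounded by a MathContext, the same division succeeds.
    new JBigDecimal("1.000").divide(new JBigDecimal("3.000"), MathContext.DECIMAL128)
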
http://git-wip-us.apache.org/repos/asf/spark/blob/864d5de6/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
index e5bbd0a..e811f1d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
@@ -412,7 +412,8 @@ private[sql] object SparkSqlSerializer2 {
               // Then, read the scale.
               val scale = in.readInt()
               // Finally, create the Decimal object and set it in the row.
-              mutableRow.update(i, Decimal(new BigDecimal(unscaledVal, scale)))
+              mutableRow.update(i,
+                Decimal(new BigDecimal(unscaledVal, scale), decimal.precision, decimal.scale))
             }
           }
           i += 1


http://git-wip-us.apache.org/repos/asf/spark/blob/864d5de6/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
index aade2e7..dedc7c4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
@@ -21,7 +21,6 @@ import java.io.OutputStream
 import java.util.{List => JList, Map => JMap}

 import scala.collection.JavaConversions._
-import scala.collection.JavaConverters._

 import net.razorvine.pickle._

@@ -182,7 +181,7 @@ object EvaluatePython {

     case (c: Double, DoubleType) => c

-    case (c: java.math.BigDecimal, dt: DecimalType) => Decimal(c)
+    case (c: java.math.BigDecimal, dt: DecimalType) => Decimal(c, dt.precision, dt.scale)

     case (c: Int, DateType) => c
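
Both call sites lean on the Decimal.apply(java.math.BigDecimal, precision,
scale) overload added above, which rescales the incoming value onto the
column's declared type. A sketch of that rescaling (values chosen purely for
illustration):

    import org.apache.spark.sql.types.Decimal

    // An unscaled value of 12345 at scale 3 is 12.345; landing it on a
    // DecimalType(10, 2) column rounds HALF_UP to 12.35.
    val fromWire = new java.math.BigDecimal(new java.math.BigInteger("12345"), 3)
    assert(Decimal(fromWire, 10, 2).toJavaBigDecimal.toPlainString == "12.35")
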
http://git-wip-us.apache.org/repos/asf/spark/blob/864d5de6/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala
index 04ab5e2..ec5668c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala
@@ -113,8 +113,12 @@ private[sql] object InferSchema {
         case INT | LONG => LongType
         // Since we do not have a data type backed by BigInteger,
         // when we see a Java BigInteger, we use DecimalType.
-        case BIG_INTEGER | BIG_DECIMAL => DecimalType.SYSTEM_DEFAULT
-        case FLOAT | DOUBLE => DoubleType
+        case BIG_INTEGER | BIG_DECIMAL =>
+          val v = parser.getDecimalValue
+          DecimalType(v.precision(), v.scale())
+        case FLOAT | DOUBLE =>
+          // TODO(davies): Should we use decimal if possible?
+          DoubleType
       }

     case VALUE_TRUE | VALUE_FALSE => BooleanType
@@ -171,9 +175,18 @@ private[sql] object InferSchema {
       // Double support larger range than fixed decimal, DecimalType.Maximum should be enough
       // in most case, also have better precision.
       case (DoubleType, t: DecimalType) =>
-        if (t == DecimalType.SYSTEM_DEFAULT) t else DoubleType
+        DoubleType
       case (t: DecimalType, DoubleType) =>
-        if (t == DecimalType.SYSTEM_DEFAULT) t else DoubleType
+        DoubleType
+      case (t1: DecimalType, t2: DecimalType) =>
+        val scale = math.max(t1.scale, t2.scale)
+        val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale)
+        if (range + scale > 38) {
+          // DecimalType can't support precision > 38
+          DoubleType
+        } else {
+          DecimalType(range + scale, scale)
+        }

       case (StructType(fields1), StructType(fields2)) =>
         val newFields = (fields1 ++ fields2).groupBy(field => field.name).map {


http://git-wip-us.apache.org/repos/asf/spark/blob/864d5de6/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
index bf0448e..f1a66c8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
@@ -84,9 +84,8 @@ private[sql] object JacksonParser {
       case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, DoubleType) =>
         parser.getDoubleValue

-      case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, DecimalType()) =>
-        // TODO: add fixed precision and scale handling
-        Decimal(parser.getDecimalValue)
+      case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, dt: DecimalType) =>
+        Decimal(parser.getDecimalValue, dt.precision, dt.scale)

       case (VALUE_NUMBER_INT, ByteType) =>
         parser.getByteValue
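
The merge rule for two inferred decimal types keeps the widest integral part
and the widest fractional part, falling back to DoubleType once the combination
would exceed the 38-digit ceiling. A worked standalone version (mergeDecimals
is a hypothetical name mirroring the hunk above):

    import org.apache.spark.sql.types.{DataType, DecimalType, DoubleType}

    def mergeDecimals(t1: DecimalType, t2: DecimalType): DataType = {
      val scale = math.max(t1.scale, t2.scale)                               // widest fraction
      val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale) // widest integral part
      if (range + scale > 38) DoubleType else DecimalType(range + scale, scale)
    }

    assert(mergeDecimals(DecimalType(5, 3), DecimalType(10, 1)) == DecimalType(12, 3))
    assert(mergeDecimals(DecimalType(38, 3), DecimalType(20, 6)) == DoubleType) // 35 + 6 > 38
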
StructField("num_num_1", LongType, true) :: - StructField("num_num_2", DecimalType.SYSTEM_DEFAULT, true) :: + StructField("num_num_2", DoubleType, true) :: StructField("num_num_3", DoubleType, true) :: StructField("num_str", StringType, true) :: StructField("str_bool", StringType, true) :: Nil) @@ -398,11 +396,9 @@ class JsonSuite extends QueryTest with TestJsonData { checkAnswer( sql("select * from jsonTable"), Row("true", 11L, null, 1.1, "13.1", "str1") :: - Row("12", null, new java.math.BigDecimal("21474836470.9"), null, null, "true") :: - Row("false", 21474836470L, - new java.math.BigDecimal("92233720368547758070"), 100, "str1", "false") :: - Row(null, 21474836570L, - new java.math.BigDecimal("1.1"), 21474836470L, "92233720368547758070", null) :: Nil + Row("12", null, 21474836470.9, null, null, "true") :: + Row("false", 21474836470L, 92233720368547758070d, 100, "str1", "false") :: + Row(null, 21474836570L, 1.1, 21474836470L, "92233720368547758070", null) :: Nil ) // Number and Boolean conflict: resolve the type as number in this query. @@ -425,8 +421,8 @@ class JsonSuite extends QueryTest with TestJsonData { // Widening to DecimalType checkAnswer( sql("select num_num_2 + 1.3 from jsonTable where num_num_2 > 1.1"), - Row(BigDecimal("21474836472.2")) :: - Row(BigDecimal("92233720368547758071.3")) :: Nil + Row(21474836472.2) :: + Row(92233720368547758071.3) :: Nil ) // Widening to Double @@ -611,7 +607,7 @@ class JsonSuite extends QueryTest with TestJsonData { val jsonDF = ctx.read.json(path) val expectedSchema = StructType( - StructField("bigInteger", DecimalType.SYSTEM_DEFAULT, true) :: + StructField("bigInteger", DecimalType(20, 0), true) :: StructField("boolean", BooleanType, true) :: StructField("double", DoubleType, true) :: StructField("integer", LongType, true) :: http://git-wip-us.apache.org/repos/asf/spark/blob/864d5de6/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index a95f70f..5c65a8e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -189,4 +189,17 @@ class ParquetQuerySuite extends QueryTest with ParquetTest { } } } + + test("SPARK-9119 Decimal should be correctly written into parquet") { + withTempPath { dir => + val basePath = dir.getCanonicalPath + val schema = StructType(Array(StructField("name", DecimalType(10, 5), false))) + val rowRDD = sqlContext.sparkContext.parallelize(Array(Row(Decimal("67123.45")))) + val df = sqlContext.createDataFrame(rowRDD, schema) + df.write.parquet(basePath) + + val decimal = sqlContext.read.parquet(basePath).first().getDecimal(0) + assert(Decimal("67123.45") === Decimal(decimal)) + } + } } http://git-wip-us.apache.org/repos/asf/spark/blob/864d5de6/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index 97e4ea2..a6a343d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ 
http://git-wip-us.apache.org/repos/asf/spark/blob/864d5de6/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
index a95f70f..5c65a8e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -189,4 +189,17 @@ class ParquetQuerySuite extends QueryTest with ParquetTest {
       }
     }
   }
+
+  test("SPARK-9119 Decimal should be correctly written into parquet") {
+    withTempPath { dir =>
+      val basePath = dir.getCanonicalPath
+      val schema = StructType(Array(StructField("name", DecimalType(10, 5), false)))
+      val rowRDD = sqlContext.sparkContext.parallelize(Array(Row(Decimal("67123.45"))))
+      val df = sqlContext.createDataFrame(rowRDD, schema)
+      df.write.parquet(basePath)
+
+      val decimal = sqlContext.read.parquet(basePath).first().getDecimal(0)
+      assert(Decimal("67123.45") === Decimal(decimal))
+    }
+  }
 }


http://git-wip-us.apache.org/repos/asf/spark/blob/864d5de6/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
index 97e4ea2..a6a343d 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
@@ -29,7 +29,6 @@ import org.apache.hadoop.hive.serde2.AbstractSerDe
 import org.apache.hadoop.hive.serde2.objectinspector._
 import org.apache.hadoop.io.Writable

-import org.apache.spark.{TaskContext, Logging}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions._
@@ -39,6 +38,7 @@ import org.apache.spark.sql.hive.HiveShim._
 import org.apache.spark.sql.hive.{HiveContext, HiveInspectors}
 import org.apache.spark.sql.types.DataType
 import org.apache.spark.util.{CircularBuffer, RedirectThread, Utils}
+import org.apache.spark.{Logging, TaskContext}

 /**
  * Transforms the input by forking and running the specified script.