http://git-wip-us.apache.org/repos/asf/spark/blob/f9969098/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/decimal/Decimal.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/decimal/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/decimal/Decimal.scala deleted file mode 100644 index 708362a..0000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/decimal/Decimal.scala +++ /dev/null @@ -1,335 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.types.decimal - -import org.apache.spark.annotation.DeveloperApi - -/** - * A mutable implementation of BigDecimal that can hold a Long if values are small enough. - * - * The semantics of the fields are as follows: - * - _precision and _scale represent the SQL precision and scale we are looking for - * - If decimalVal is set, it represents the whole decimal value - * - Otherwise, the decimal value is longVal / (10 ** _scale) - */ -final class Decimal extends Ordered[Decimal] with Serializable { - import Decimal.{MAX_LONG_DIGITS, POW_10, ROUNDING_MODE, BIG_DEC_ZERO} - - private var decimalVal: BigDecimal = null - private var longVal: Long = 0L - private var _precision: Int = 1 - private var _scale: Int = 0 - - def precision: Int = _precision - def scale: Int = _scale - - /** - * Set this Decimal to the given Long. Will have precision 20 and scale 0. - */ - def set(longVal: Long): Decimal = { - if (longVal <= -POW_10(MAX_LONG_DIGITS) || longVal >= POW_10(MAX_LONG_DIGITS)) { - // We can't represent this compactly as a long without risking overflow - this.decimalVal = BigDecimal(longVal) - this.longVal = 0L - } else { - this.decimalVal = null - this.longVal = longVal - } - this._precision = 20 - this._scale = 0 - this - } - - /** - * Set this Decimal to the given Int. Will have precision 10 and scale 0. - */ - def set(intVal: Int): Decimal = { - this.decimalVal = null - this.longVal = intVal - this._precision = 10 - this._scale = 0 - this - } - - /** - * Set this Decimal to the given unscaled Long, with a given precision and scale. - */ - def set(unscaled: Long, precision: Int, scale: Int): Decimal = { - if (setOrNull(unscaled, precision, scale) == null) { - throw new IllegalArgumentException("Unscaled value too large for precision") - } - this - } - - /** - * Set this Decimal to the given unscaled Long, with a given precision and scale, - * and return it, or return null if it cannot be set due to overflow. 
-   */
-  def setOrNull(unscaled: Long, precision: Int, scale: Int): Decimal = {
-    if (unscaled <= -POW_10(MAX_LONG_DIGITS) || unscaled >= POW_10(MAX_LONG_DIGITS)) {
-      // We can't represent this compactly as a long without risking overflow
-      if (precision < 19) {
-        return null  // Requested precision is too low to represent this value
-      }
-      this.decimalVal = BigDecimal(unscaled, scale)
-      this.longVal = 0L
-    } else {
-      val p = POW_10(math.min(precision, MAX_LONG_DIGITS))
-      if (unscaled <= -p || unscaled >= p) {
-        return null  // Requested precision is too low to represent this value
-      }
-      this.decimalVal = null
-      this.longVal = unscaled
-    }
-    this._precision = precision
-    this._scale = scale
-    this
-  }
-
-  /**
-   * Set this Decimal to the given BigDecimal value, with a given precision and scale.
-   */
-  def set(decimal: BigDecimal, precision: Int, scale: Int): Decimal = {
-    this.decimalVal = decimal.setScale(scale, ROUNDING_MODE)
-    require(decimalVal.precision <= precision, "Overflowed precision")
-    this.longVal = 0L
-    this._precision = precision
-    this._scale = scale
-    this
-  }
-
-  /**
-   * Set this Decimal to the given BigDecimal value, inheriting its precision and scale.
-   */
-  def set(decimal: BigDecimal): Decimal = {
-    this.decimalVal = decimal
-    this.longVal = 0L
-    this._precision = decimal.precision
-    this._scale = decimal.scale
-    this
-  }
-
-  /**
-   * Set this Decimal to the given Decimal value.
-   */
-  def set(decimal: Decimal): Decimal = {
-    this.decimalVal = decimal.decimalVal
-    this.longVal = decimal.longVal
-    this._precision = decimal._precision
-    this._scale = decimal._scale
-    this
-  }
-
-  def toBigDecimal: BigDecimal = {
-    if (decimalVal.ne(null)) {
-      decimalVal
-    } else {
-      BigDecimal(longVal, _scale)
-    }
-  }
-
-  def toUnscaledLong: Long = {
-    if (decimalVal.ne(null)) {
-      decimalVal.underlying().unscaledValue().longValue()
-    } else {
-      longVal
-    }
-  }
-
-  override def toString: String = toBigDecimal.toString()
-
-  @DeveloperApi
-  def toDebugString: String = {
-    if (decimalVal.ne(null)) {
-      s"Decimal(expanded,$decimalVal,$precision,$scale)"
-    } else {
-      s"Decimal(compact,$longVal,$precision,$scale)"
-    }
-  }
-
-  def toDouble: Double = toBigDecimal.doubleValue()
-
-  def toFloat: Float = toBigDecimal.floatValue()
-
-  def toLong: Long = {
-    if (decimalVal.eq(null)) {
-      longVal / POW_10(_scale)
-    } else {
-      decimalVal.longValue()
-    }
-  }
-
-  def toInt: Int = toLong.toInt
-
-  def toShort: Short = toLong.toShort
-
-  def toByte: Byte = toLong.toByte
-
-  /**
-   * Update precision and scale while keeping our value the same, and return true if successful.
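-   * For example, a Decimal holding unscaled value 123456 with scale 2 represents 1234.56,
-   * and changePrecision(10, 1) rounds it half-up to 1234.6.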
-   *
-   * @return true if successful, false if overflow would occur
-   */
-  def changePrecision(precision: Int, scale: Int): Boolean = {
-    // First, update our longVal if we can, or transfer over to using a BigDecimal
-    if (decimalVal.eq(null)) {
-      if (scale < _scale) {
-        // Easier case: we just need to divide our scale down
-        val diff = _scale - scale
-        val droppedDigits = longVal % POW_10(diff)
-        longVal /= POW_10(diff)
-        if (math.abs(droppedDigits) * 2 >= POW_10(diff)) {
-          longVal += (if (longVal < 0) -1L else 1L)
-        }
-      } else if (scale > _scale) {
-        // We might be able to multiply longVal by a power of 10 and not overflow, but if not,
-        // switch to using a BigDecimal
-        val diff = scale - _scale
-        val p = POW_10(math.max(MAX_LONG_DIGITS - diff, 0))
-        if (diff <= MAX_LONG_DIGITS && longVal > -p && longVal < p) {
-          // Multiplying longVal by POW_10(diff) will still keep it below MAX_LONG_DIGITS
-          longVal *= POW_10(diff)
-        } else {
-          // Give up on using Longs; switch to BigDecimal, which we'll modify below
-          decimalVal = BigDecimal(longVal, _scale)
-        }
-      }
-      // In both cases, we will check whether our precision is okay below
-    }
-
-    if (decimalVal.ne(null)) {
-      // We get here if either we started with a BigDecimal, or we switched to one because we would
-      // have overflowed our Long; in either case we must rescale decimalVal to the new scale.
-      val newVal = decimalVal.setScale(scale, ROUNDING_MODE)
-      if (newVal.precision > precision) {
-        return false
-      }
-      decimalVal = newVal
-    } else {
-      // We're still using Longs, but we should check whether we match the new precision
-      val p = POW_10(math.min(precision, MAX_LONG_DIGITS))
-      if (longVal <= -p || longVal >= p) {
-        // Note that we shouldn't have been able to fix this by switching to BigDecimal
-        return false
-      }
-    }
-
-    _precision = precision
-    _scale = scale
-    true
-  }
-
-  override def clone(): Decimal = new Decimal().set(this)
-
-  override def compare(other: Decimal): Int = {
-    if (decimalVal.eq(null) && other.decimalVal.eq(null) && _scale == other._scale) {
-      if (longVal < other.longVal) -1 else if (longVal == other.longVal) 0 else 1
-    } else {
-      toBigDecimal.compare(other.toBigDecimal)
-    }
-  }
-
-  override def equals(other: Any) = other match {
-    case d: Decimal =>
-      compare(d) == 0
-    case _ =>
-      false
-  }
-
-  override def hashCode(): Int = toBigDecimal.hashCode()
-
-  def isZero: Boolean = if (decimalVal.ne(null)) decimalVal == BIG_DEC_ZERO else longVal == 0
-
-  def + (that: Decimal): Decimal = Decimal(toBigDecimal + that.toBigDecimal)
-
-  def - (that: Decimal): Decimal = Decimal(toBigDecimal - that.toBigDecimal)
-
-  def * (that: Decimal): Decimal = Decimal(toBigDecimal * that.toBigDecimal)
-
-  def / (that: Decimal): Decimal =
-    if (that.isZero) null else Decimal(toBigDecimal / that.toBigDecimal)
-
-  def % (that: Decimal): Decimal =
-    if (that.isZero) null else Decimal(toBigDecimal % that.toBigDecimal)
-
-  def remainder(that: Decimal): Decimal = this % that
-
-  def unary_- : Decimal = {
-    if (decimalVal.ne(null)) {
-      Decimal(-decimalVal)
-    } else {
-      Decimal(-longVal, precision, scale)
-    }
-  }
-}
-
-object Decimal {
-  private val ROUNDING_MODE = BigDecimal.RoundingMode.HALF_UP
-
-  /** Maximum number of decimal digits that can always be represented in a Long */
-  val MAX_LONG_DIGITS = 18
-
-  private val POW_10 = Array.tabulate[Long](MAX_LONG_DIGITS + 1)(i => math.pow(10, i).toLong)
-
-  private val BIG_DEC_ZERO = BigDecimal(0)
-
-  // Relies on the implicit Double => BigDecimal conversion, dispatching to set(decimal: BigDecimal)
-  def apply(value: Double): Decimal = new Decimal().set(value)
-
-  def apply(value: Long): Decimal = new Decimal().set(value)
-
-  def apply(value: Int): Decimal = new Decimal().set(value)
-
-  def apply(value: BigDecimal): Decimal = new Decimal().set(value)
-
-  def apply(value: BigDecimal, precision: Int, scale: Int): Decimal =
-    new Decimal().set(value, precision, scale)
-
-  def apply(unscaled: Long, precision: Int, scale: Int): Decimal =
-    new Decimal().set(unscaled, precision, scale)
-
-  def apply(value: String): Decimal = new Decimal().set(BigDecimal(value))
-
-  // Evidence parameters for Decimal considered either as Fractional or Integral. We provide two
-  // parameters inheriting from a common trait since both traits define mkNumericOps.
-  // See scala.math's Numeric.scala for examples for Scala's built-in types.
-
-  /** Common methods for Decimal evidence parameters */
-  trait DecimalIsConflicted extends Numeric[Decimal] {
-    override def plus(x: Decimal, y: Decimal): Decimal = x + y
-    override def times(x: Decimal, y: Decimal): Decimal = x * y
-    override def minus(x: Decimal, y: Decimal): Decimal = x - y
-    override def negate(x: Decimal): Decimal = -x
-    override def toDouble(x: Decimal): Double = x.toDouble
-    override def toFloat(x: Decimal): Float = x.toFloat
-    override def toInt(x: Decimal): Int = x.toInt
-    override def toLong(x: Decimal): Long = x.toLong
-    override def fromInt(x: Int): Decimal = new Decimal().set(x)
-    override def compare(x: Decimal, y: Decimal): Int = x.compare(y)
-  }
-
-  /** A [[scala.math.Fractional]] evidence parameter for Decimals. */
-  object DecimalIsFractional extends DecimalIsConflicted with Fractional[Decimal] {
-    override def div(x: Decimal, y: Decimal): Decimal = x / y
-  }
-
-  /** An [[scala.math.Integral]] evidence parameter for Decimals. */
-  object DecimalAsIfIntegral extends DecimalIsConflicted with Integral[Decimal] {
-    override def quot(x: Decimal, y: Decimal): Decimal = x / y
-    override def rem(x: Decimal, y: Decimal): Decimal = x % y
-  }
-}
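To see how the compact and expanded representations divide the work, here is a small REPL-style
sketch against the API above (this commit relocates the class to the
org.apache.spark.sql.types.decimal package, as the new imports further down show, so adjust the
import to whichever tree you are building against):

  import org.apache.spark.sql.types.decimal.Decimal

  // 123456 with scale 2 fits the compact Long representation: 1234.56.
  val d = Decimal(123456L, 10, 2)
  assert(d.toDebugString.startsWith("Decimal(compact"))

  // Lowering the scale rounds HALF_UP: 1234.56 becomes 1234.6.
  assert(d.changePrecision(10, 1))
  assert(d.toString == "1234.6")

  // More than 18 digits cannot be held compactly, so this takes the BigDecimal path.
  val big = Decimal(BigDecimal("12345678901234567890"))
  assert(big.toDebugString.startsWith("Decimal(expanded"))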
http://git-wip-us.apache.org/repos/asf/spark/blob/f9969098/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/package.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/package.scala deleted file mode 100644 index de24449..0000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/package.scala +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst -/** - * Contains a type system for attributes produced by relations, including complex types like - * structs, arrays and maps. - */ -package object types http://git-wip-us.apache.org/repos/asf/spark/blob/f9969098/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/Metadata.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/Metadata.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/Metadata.scala deleted file mode 100755 index 8172733..0000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/Metadata.scala +++ /dev/null @@ -1,258 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.util - -import scala.collection.mutable - -import org.json4s._ -import org.json4s.jackson.JsonMethods._ - -/** - * Metadata is a wrapper over Map[String, Any] that limits the value type to simple ones: Boolean, - * Long, Double, String, Metadata, Array[Boolean], Array[Long], Array[Double], Array[String], and - * Array[Metadata]. JSON is used for serialization. - * - * The default constructor is private. User should use either [[MetadataBuilder]] or - * [[Metadata$#fromJson]] to create Metadata instances. 
- * - * @param map an immutable map that stores the data - */ -sealed class Metadata private[util] (private[util] val map: Map[String, Any]) extends Serializable { - - /** Tests whether this Metadata contains a binding for a key. */ - def contains(key: String): Boolean = map.contains(key) - - /** Gets a Long. */ - def getLong(key: String): Long = get(key) - - /** Gets a Double. */ - def getDouble(key: String): Double = get(key) - - /** Gets a Boolean. */ - def getBoolean(key: String): Boolean = get(key) - - /** Gets a String. */ - def getString(key: String): String = get(key) - - /** Gets a Metadata. */ - def getMetadata(key: String): Metadata = get(key) - - /** Gets a Long array. */ - def getLongArray(key: String): Array[Long] = get(key) - - /** Gets a Double array. */ - def getDoubleArray(key: String): Array[Double] = get(key) - - /** Gets a Boolean array. */ - def getBooleanArray(key: String): Array[Boolean] = get(key) - - /** Gets a String array. */ - def getStringArray(key: String): Array[String] = get(key) - - /** Gets a Metadata array. */ - def getMetadataArray(key: String): Array[Metadata] = get(key) - - /** Converts to its JSON representation. */ - def json: String = compact(render(jsonValue)) - - override def toString: String = json - - override def equals(obj: Any): Boolean = { - obj match { - case that: Metadata => - if (map.keySet == that.map.keySet) { - map.keys.forall { k => - (map(k), that.map(k)) match { - case (v0: Array[_], v1: Array[_]) => - v0.view == v1.view - case (v0, v1) => - v0 == v1 - } - } - } else { - false - } - case other => - false - } - } - - override def hashCode: Int = Metadata.hash(this) - - private def get[T](key: String): T = { - map(key).asInstanceOf[T] - } - - private[sql] def jsonValue: JValue = Metadata.toJsonValue(this) -} - -object Metadata { - - /** Returns an empty Metadata. */ - def empty: Metadata = new Metadata(Map.empty) - - /** Creates a Metadata instance from JSON. */ - def fromJson(json: String): Metadata = { - fromJObject(parse(json).asInstanceOf[JObject]) - } - - /** Creates a Metadata instance from JSON AST. */ - private[sql] def fromJObject(jObj: JObject): Metadata = { - val builder = new MetadataBuilder - jObj.obj.foreach { - case (key, JInt(value)) => - builder.putLong(key, value.toLong) - case (key, JDouble(value)) => - builder.putDouble(key, value) - case (key, JBool(value)) => - builder.putBoolean(key, value) - case (key, JString(value)) => - builder.putString(key, value) - case (key, o: JObject) => - builder.putMetadata(key, fromJObject(o)) - case (key, JArray(value)) => - if (value.isEmpty) { - // If it is an empty array, we cannot infer its element type. We put an empty Array[Long]. - builder.putLongArray(key, Array.empty) - } else { - value.head match { - case _: JInt => - builder.putLongArray(key, value.asInstanceOf[List[JInt]].map(_.num.toLong).toArray) - case _: JDouble => - builder.putDoubleArray(key, value.asInstanceOf[List[JDouble]].map(_.num).toArray) - case _: JBool => - builder.putBooleanArray(key, value.asInstanceOf[List[JBool]].map(_.value).toArray) - case _: JString => - builder.putStringArray(key, value.asInstanceOf[List[JString]].map(_.s).toArray) - case _: JObject => - builder.putMetadataArray( - key, value.asInstanceOf[List[JObject]].map(fromJObject).toArray) - case other => - throw new RuntimeException(s"Do not support array of type ${other.getClass}.") - } - } - case other => - throw new RuntimeException(s"Do not support type ${other.getClass}.") - } - builder.build() - } - - /** Converts to JSON AST. 
*/ - private def toJsonValue(obj: Any): JValue = { - obj match { - case map: Map[_, _] => - val fields = map.toList.map { case (k: String, v) => (k, toJsonValue(v)) } - JObject(fields) - case arr: Array[_] => - val values = arr.toList.map(toJsonValue) - JArray(values) - case x: Long => - JInt(x) - case x: Double => - JDouble(x) - case x: Boolean => - JBool(x) - case x: String => - JString(x) - case x: Metadata => - toJsonValue(x.map) - case other => - throw new RuntimeException(s"Do not support type ${other.getClass}.") - } - } - - /** Computes the hash code for the types we support. */ - private def hash(obj: Any): Int = { - obj match { - case map: Map[_, _] => - map.mapValues(hash).## - case arr: Array[_] => - // Seq.empty[T] has the same hashCode regardless of T. - arr.toSeq.map(hash).## - case x: Long => - x.## - case x: Double => - x.## - case x: Boolean => - x.## - case x: String => - x.## - case x: Metadata => - hash(x.map) - case other => - throw new RuntimeException(s"Do not support type ${other.getClass}.") - } - } -} - -/** - * Builder for [[Metadata]]. If there is a key collision, the latter will overwrite the former. - */ -class MetadataBuilder { - - private val map: mutable.Map[String, Any] = mutable.Map.empty - - /** Returns the immutable version of this map. Used for java interop. */ - protected def getMap = map.toMap - - /** Include the content of an existing [[Metadata]] instance. */ - def withMetadata(metadata: Metadata): this.type = { - map ++= metadata.map - this - } - - /** Puts a Long. */ - def putLong(key: String, value: Long): this.type = put(key, value) - - /** Puts a Double. */ - def putDouble(key: String, value: Double): this.type = put(key, value) - - /** Puts a Boolean. */ - def putBoolean(key: String, value: Boolean): this.type = put(key, value) - - /** Puts a String. */ - def putString(key: String, value: String): this.type = put(key, value) - - /** Puts a [[Metadata]]. */ - def putMetadata(key: String, value: Metadata): this.type = put(key, value) - - /** Puts a Long array. */ - def putLongArray(key: String, value: Array[Long]): this.type = put(key, value) - - /** Puts a Double array. */ - def putDoubleArray(key: String, value: Array[Double]): this.type = put(key, value) - - /** Puts a Boolean array. */ - def putBooleanArray(key: String, value: Array[Boolean]): this.type = put(key, value) - - /** Puts a String array. */ - def putStringArray(key: String, value: Array[String]): this.type = put(key, value) - - /** Puts a [[Metadata]] array. */ - def putMetadataArray(key: String, value: Array[Metadata]): this.type = put(key, value) - - /** Builds the [[Metadata]] instance. */ - def build(): Metadata = { - new Metadata(map.toMap) - } - - private def put(key: String, value: Any): this.type = { - map.put(key, value) - this - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/f9969098/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala new file mode 100644 index 0000000..2a8914c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
+ * See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import java.text.SimpleDateFormat
+
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.types.decimal.Decimal
+
+
+protected[sql] object DataTypeConversions {
+
+  def stringToTime(s: String): java.util.Date = {
+    if (!s.contains('T')) {
+      // JDBC escape string
+      if (s.contains(' ')) {
+        java.sql.Timestamp.valueOf(s)
+      } else {
+        java.sql.Date.valueOf(s)
+      }
+    } else if (s.endsWith("Z")) {
+      // ISO 8601 with the 'Z' (UTC) designator; rewrite it as an explicit GMT offset
+      stringToTime(s.substring(0, s.length - 1) + "GMT-00:00")
+    } else if (s.indexOf("GMT") == -1) {
+      // ISO 8601 with a numeric timezone offset such as "+00:00"
+      val inset = "+00.00".length
+      val s0 = s.substring(0, s.length - inset)
+      val s1 = s.substring(s.length - inset, s.length)
+      if (s0.substring(s0.lastIndexOf(':')).contains('.')) {
+        stringToTime(s0 + "GMT" + s1)
+      } else {
+        stringToTime(s0 + ".0GMT" + s1)
+      }
+    } else {
+      // ISO 8601 with an explicit GMT marker before the offset
+      val ISO8601GMT: SimpleDateFormat = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSz")
+      ISO8601GMT.parse(s)
+    }
+  }
+
+  /** Converts Java objects to catalyst rows / types */
+  def convertJavaToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match {
+    case (obj, udt: UserDefinedType[_]) => ScalaReflection.convertToCatalyst(obj, udt) // Scala type
+    case (d: java.math.BigDecimal, _) => Decimal(BigDecimal(d))
+    case (other, _) => other
+  }
+
+  /** Converts catalyst types to Java objects */
+  def convertCatalystToJava(a: Any): Any = a match {
+    case d: scala.math.BigDecimal => d.underlying()
+    case other => other
+  }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/f9969098/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypes.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypes.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypes.java
new file mode 100644
index 0000000..e457542
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypes.java
@@ -0,0 +1,208 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types;
+
+import java.util.*;
+
+/**
+ * To get/create specific data types, users should use singleton objects and factory methods
+ * provided by this class.
+ */
+public class DataTypes {
+  /**
+   * Gets the StringType object.
+   */
+  public static final DataType StringType = StringType$.MODULE$;
+
+  /**
+   * Gets the BinaryType object.
+   */
+  public static final DataType BinaryType = BinaryType$.MODULE$;
+
+  /**
+   * Gets the BooleanType object.
+   */
+  public static final DataType BooleanType = BooleanType$.MODULE$;
+
+  /**
+   * Gets the DateType object.
+   */
+  public static final DataType DateType = DateType$.MODULE$;
+
+  /**
+   * Gets the TimestampType object.
+   */
+  public static final DataType TimestampType = TimestampType$.MODULE$;
+
+  /**
+   * Gets the DoubleType object.
+   */
+  public static final DataType DoubleType = DoubleType$.MODULE$;
+
+  /**
+   * Gets the FloatType object.
+   */
+  public static final DataType FloatType = FloatType$.MODULE$;
+
+  /**
+   * Gets the ByteType object.
+   */
+  public static final DataType ByteType = ByteType$.MODULE$;
+
+  /**
+   * Gets the IntegerType object.
+   */
+  public static final DataType IntegerType = IntegerType$.MODULE$;
+
+  /**
+   * Gets the LongType object.
+   */
+  public static final DataType LongType = LongType$.MODULE$;
+
+  /**
+   * Gets the ShortType object.
+   */
+  public static final DataType ShortType = ShortType$.MODULE$;
+
+  /**
+   * Gets the NullType object.
+   */
+  public static final DataType NullType = NullType$.MODULE$;
+
+  /**
+   * Creates an ArrayType by specifying the data type of elements ({@code elementType}).
+   * The field of {@code containsNull} is set to {@code true}.
+   */
+  public static ArrayType createArrayType(DataType elementType) {
+    if (elementType == null) {
+      throw new IllegalArgumentException("elementType should not be null.");
+    }
+    return new ArrayType(elementType, true);
+  }
+
+  /**
+   * Creates an ArrayType by specifying the data type of elements ({@code elementType}) and
+   * whether the array contains null values ({@code containsNull}).
+   */
+  public static ArrayType createArrayType(DataType elementType, boolean containsNull) {
+    if (elementType == null) {
+      throw new IllegalArgumentException("elementType should not be null.");
+    }
+    return new ArrayType(elementType, containsNull);
+  }
+
+  /**
+   * Creates a DecimalType with the given precision and scale.
+   */
+  public static DecimalType createDecimalType(int precision, int scale) {
+    return DecimalType$.MODULE$.apply(precision, scale);
+  }
+
+  /**
+   * Creates a DecimalType with unlimited precision and scale.
+   */
+  public static DecimalType createDecimalType() {
+    return DecimalType$.MODULE$.Unlimited();
+  }
+
+  /**
+   * Creates a MapType by specifying the data type of keys ({@code keyType}) and values
+   * ({@code valueType}). The field of {@code valueContainsNull} is set to {@code true}.
+   */
+  public static MapType createMapType(DataType keyType, DataType valueType) {
+    if (keyType == null) {
+      throw new IllegalArgumentException("keyType should not be null.");
+    }
+    if (valueType == null) {
+      throw new IllegalArgumentException("valueType should not be null.");
+    }
+    return new MapType(keyType, valueType, true);
+  }
+
+  /**
+   * Creates a MapType by specifying the data type of keys ({@code keyType}), the data type of
+   * values ({@code valueType}), and whether values contain any null value
+   * ({@code valueContainsNull}).
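+   * For example, {@code createMapType(StringType, IntegerType, false)} describes a map from
+   * strings to integers whose values may not be null.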
+ */ + public static MapType createMapType( + DataType keyType, + DataType valueType, + boolean valueContainsNull) { + if (keyType == null) { + throw new IllegalArgumentException("keyType should not be null."); + } + if (valueType == null) { + throw new IllegalArgumentException("valueType should not be null."); + } + return new MapType(keyType, valueType, valueContainsNull); + } + + /** + * Creates a StructField by specifying the name ({@code name}), data type ({@code dataType}) and + * whether values of this field can be null values ({@code nullable}). + */ + public static StructField createStructField( + String name, + DataType dataType, + boolean nullable, + Metadata metadata) { + if (name == null) { + throw new IllegalArgumentException("name should not be null."); + } + if (dataType == null) { + throw new IllegalArgumentException("dataType should not be null."); + } + if (metadata == null) { + throw new IllegalArgumentException("metadata should not be null."); + } + return new StructField(name, dataType, nullable, metadata); + } + + /** + * Creates a StructField with empty metadata. + * + * @see #createStructField(String, DataType, boolean, Metadata) + */ + public static StructField createStructField(String name, DataType dataType, boolean nullable) { + return createStructField(name, dataType, nullable, (new MetadataBuilder()).build()); + } + + /** + * Creates a StructType with the given list of StructFields ({@code fields}). + */ + public static StructType createStructType(List<StructField> fields) { + return createStructType(fields.toArray(new StructField[0])); + } + + /** + * Creates a StructType with the given StructField array ({@code fields}). + */ + public static StructType createStructType(StructField[] fields) { + if (fields == null) { + throw new IllegalArgumentException("fields should not be null."); + } + Set<String> distinctNames = new HashSet<String>(); + for (StructField field : fields) { + if (field == null) { + throw new IllegalArgumentException( + "fields should not contain any null."); + } + + distinctNames.add(field.name()); + } + if (distinctNames.size() != fields.length) { + throw new IllegalArgumentException("fields should have distinct names."); + } + + return StructType$.MODULE$.apply(fields); + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/f9969098/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala new file mode 100755 index 0000000..e50e976 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala @@ -0,0 +1,268 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.collection.mutable + +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.annotation.DeveloperApi + + +/** + * :: DeveloperApi :: + * + * Metadata is a wrapper over Map[String, Any] that limits the value type to simple ones: Boolean, + * Long, Double, String, Metadata, Array[Boolean], Array[Long], Array[Double], Array[String], and + * Array[Metadata]. JSON is used for serialization. + * + * The default constructor is private. User should use either [[MetadataBuilder]] or + * [[Metadata.fromJson()]] to create Metadata instances. + * + * @param map an immutable map that stores the data + */ +@DeveloperApi +sealed class Metadata private[types] (private[types] val map: Map[String, Any]) + extends Serializable { + + /** Tests whether this Metadata contains a binding for a key. */ + def contains(key: String): Boolean = map.contains(key) + + /** Gets a Long. */ + def getLong(key: String): Long = get(key) + + /** Gets a Double. */ + def getDouble(key: String): Double = get(key) + + /** Gets a Boolean. */ + def getBoolean(key: String): Boolean = get(key) + + /** Gets a String. */ + def getString(key: String): String = get(key) + + /** Gets a Metadata. */ + def getMetadata(key: String): Metadata = get(key) + + /** Gets a Long array. */ + def getLongArray(key: String): Array[Long] = get(key) + + /** Gets a Double array. */ + def getDoubleArray(key: String): Array[Double] = get(key) + + /** Gets a Boolean array. */ + def getBooleanArray(key: String): Array[Boolean] = get(key) + + /** Gets a String array. */ + def getStringArray(key: String): Array[String] = get(key) + + /** Gets a Metadata array. */ + def getMetadataArray(key: String): Array[Metadata] = get(key) + + /** Converts to its JSON representation. */ + def json: String = compact(render(jsonValue)) + + override def toString: String = json + + override def equals(obj: Any): Boolean = { + obj match { + case that: Metadata => + if (map.keySet == that.map.keySet) { + map.keys.forall { k => + (map(k), that.map(k)) match { + case (v0: Array[_], v1: Array[_]) => + v0.view == v1.view + case (v0, v1) => + v0 == v1 + } + } + } else { + false + } + case other => + false + } + } + + override def hashCode: Int = Metadata.hash(this) + + private def get[T](key: String): T = { + map(key).asInstanceOf[T] + } + + private[sql] def jsonValue: JValue = Metadata.toJsonValue(this) +} + +object Metadata { + + /** Returns an empty Metadata. */ + def empty: Metadata = new Metadata(Map.empty) + + /** Creates a Metadata instance from JSON. */ + def fromJson(json: String): Metadata = { + fromJObject(parse(json).asInstanceOf[JObject]) + } + + /** Creates a Metadata instance from JSON AST. */ + private[sql] def fromJObject(jObj: JObject): Metadata = { + val builder = new MetadataBuilder + jObj.obj.foreach { + case (key, JInt(value)) => + builder.putLong(key, value.toLong) + case (key, JDouble(value)) => + builder.putDouble(key, value) + case (key, JBool(value)) => + builder.putBoolean(key, value) + case (key, JString(value)) => + builder.putString(key, value) + case (key, o: JObject) => + builder.putMetadata(key, fromJObject(o)) + case (key, JArray(value)) => + if (value.isEmpty) { + // If it is an empty array, we cannot infer its element type. We put an empty Array[Long]. 
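+          // For example, fromJson("""{"tags": []}""").getLongArray("tags") is an empty array.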
+ builder.putLongArray(key, Array.empty) + } else { + value.head match { + case _: JInt => + builder.putLongArray(key, value.asInstanceOf[List[JInt]].map(_.num.toLong).toArray) + case _: JDouble => + builder.putDoubleArray(key, value.asInstanceOf[List[JDouble]].map(_.num).toArray) + case _: JBool => + builder.putBooleanArray(key, value.asInstanceOf[List[JBool]].map(_.value).toArray) + case _: JString => + builder.putStringArray(key, value.asInstanceOf[List[JString]].map(_.s).toArray) + case _: JObject => + builder.putMetadataArray( + key, value.asInstanceOf[List[JObject]].map(fromJObject).toArray) + case other => + throw new RuntimeException(s"Do not support array of type ${other.getClass}.") + } + } + case other => + throw new RuntimeException(s"Do not support type ${other.getClass}.") + } + builder.build() + } + + /** Converts to JSON AST. */ + private def toJsonValue(obj: Any): JValue = { + obj match { + case map: Map[_, _] => + val fields = map.toList.map { case (k: String, v) => (k, toJsonValue(v)) } + JObject(fields) + case arr: Array[_] => + val values = arr.toList.map(toJsonValue) + JArray(values) + case x: Long => + JInt(x) + case x: Double => + JDouble(x) + case x: Boolean => + JBool(x) + case x: String => + JString(x) + case x: Metadata => + toJsonValue(x.map) + case other => + throw new RuntimeException(s"Do not support type ${other.getClass}.") + } + } + + /** Computes the hash code for the types we support. */ + private def hash(obj: Any): Int = { + obj match { + case map: Map[_, _] => + map.mapValues(hash).## + case arr: Array[_] => + // Seq.empty[T] has the same hashCode regardless of T. + arr.toSeq.map(hash).## + case x: Long => + x.## + case x: Double => + x.## + case x: Boolean => + x.## + case x: String => + x.## + case x: Metadata => + hash(x.map) + case other => + throw new RuntimeException(s"Do not support type ${other.getClass}.") + } + } +} + +/** + * :: DeveloperApi :: + * + * Builder for [[Metadata]]. If there is a key collision, the latter will overwrite the former. + */ +@DeveloperApi +class MetadataBuilder { + + private val map: mutable.Map[String, Any] = mutable.Map.empty + + /** Returns the immutable version of this map. Used for java interop. */ + protected def getMap = map.toMap + + /** Include the content of an existing [[Metadata]] instance. */ + def withMetadata(metadata: Metadata): this.type = { + map ++= metadata.map + this + } + + /** Puts a Long. */ + def putLong(key: String, value: Long): this.type = put(key, value) + + /** Puts a Double. */ + def putDouble(key: String, value: Double): this.type = put(key, value) + + /** Puts a Boolean. */ + def putBoolean(key: String, value: Boolean): this.type = put(key, value) + + /** Puts a String. */ + def putString(key: String, value: String): this.type = put(key, value) + + /** Puts a [[Metadata]]. */ + def putMetadata(key: String, value: Metadata): this.type = put(key, value) + + /** Puts a Long array. */ + def putLongArray(key: String, value: Array[Long]): this.type = put(key, value) + + /** Puts a Double array. */ + def putDoubleArray(key: String, value: Array[Double]): this.type = put(key, value) + + /** Puts a Boolean array. */ + def putBooleanArray(key: String, value: Array[Boolean]): this.type = put(key, value) + + /** Puts a String array. */ + def putStringArray(key: String, value: Array[String]): this.type = put(key, value) + + /** Puts a [[Metadata]] array. */ + def putMetadataArray(key: String, value: Array[Metadata]): this.type = put(key, value) + + /** Builds the [[Metadata]] instance. 
+   */
+  def build(): Metadata = {
+    new Metadata(map.toMap)
+  }
+
+  private def put(key: String, value: Any): this.type = {
+    map.put(key, value)
+    this
+  }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/f9969098/sql/catalyst/src/main/scala/org/apache/spark/sql/types/SQLUserDefinedType.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/SQLUserDefinedType.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/SQLUserDefinedType.java
new file mode 100644
index 0000000..a64d2bb
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/SQLUserDefinedType.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types;
+
+import java.lang.annotation.*;
+
+import org.apache.spark.annotation.DeveloperApi;
+
+/**
+ * ::DeveloperApi::
+ * A user-defined type which can be automatically recognized by a SQLContext and registered.
+ *
+ * WARNING: This annotation will only work if both Java and Scala reflection return the same class
+ * names (after erasure) for the UDT. This will NOT be the case when, e.g., the UDT class
+ * is enclosed in an object (a singleton).
+ *
+ * WARNING: UDTs are currently only supported from Scala.
+ */
+// TODO: Should @Documented be used here?
+@DeveloperApi
+@Retention(RetentionPolicy.RUNTIME)
+@Target(ElementType.TYPE)
+public @interface SQLUserDefinedType {
+
+  /**
+   * Returns an instance of the UserDefinedType which can serialize and deserialize the user
+   * class to and from Catalyst built-in types.
+   */
+  Class<? extends UserDefinedType<?>> udt();
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/f9969098/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
new file mode 100644
index 0000000..fa0a355
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
@@ -0,0 +1,900 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import java.sql.{Date, Timestamp}
+
+import scala.math.Numeric.{FloatAsIfIntegral, DoubleAsIfIntegral}
+import scala.reflect.ClassTag
+import scala.reflect.runtime.universe.{TypeTag, runtimeMirror, typeTag}
+import scala.util.parsing.combinator.RegexParsers
+
+import org.json4s._
+import org.json4s.JsonAST.JValue
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.ScalaReflectionLock
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
+import org.apache.spark.sql.types.decimal._
+import org.apache.spark.util.Utils
+
+
+object DataType {
+  def fromJson(json: String): DataType = parseDataType(parse(json))
+
+  private object JSortedObject {
+    def unapplySeq(value: JValue): Option[List[(String, JValue)]] = value match {
+      case JObject(seq) => Some(seq.toList.sortBy(_._1))
+      case _ => None
+    }
+  }
+
+  // NOTE: Map fields must be sorted in alphabetical order to keep consistent with the Python side.
+  private def parseDataType(json: JValue): DataType = json match {
+    case JString(name) =>
+      PrimitiveType.nameToType(name)
+
+    case JSortedObject(
+        ("containsNull", JBool(n)),
+        ("elementType", t: JValue),
+        ("type", JString("array"))) =>
+      ArrayType(parseDataType(t), n)
+
+    case JSortedObject(
+        ("keyType", k: JValue),
+        ("type", JString("map")),
+        ("valueContainsNull", JBool(n)),
+        ("valueType", v: JValue)) =>
+      MapType(parseDataType(k), parseDataType(v), n)
+
+    case JSortedObject(
+        ("fields", JArray(fields)),
+        ("type", JString("struct"))) =>
+      StructType(fields.map(parseStructField))
+
+    case JSortedObject(
+        ("class", JString(udtClass)),
+        ("pyClass", _),
+        ("sqlType", _),
+        ("type", JString("udt"))) =>
+      Class.forName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]]
+  }
+
+  private def parseStructField(json: JValue): StructField = json match {
+    case JSortedObject(
+        ("metadata", metadata: JObject),
+        ("name", JString(name)),
+        ("nullable", JBool(nullable)),
+        ("type", dataType: JValue)) =>
+      StructField(name, parseDataType(dataType), nullable, Metadata.fromJObject(metadata))
+    // Support reading schema when 'metadata' is missing.
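+    // (e.g. JSON written by older releases, before StructField carried a metadata entry)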
+ case JSortedObject( + ("name", JString(name)), + ("nullable", JBool(nullable)), + ("type", dataType: JValue)) => + StructField(name, parseDataType(dataType), nullable) + } + + @deprecated("Use DataType.fromJson instead", "1.2.0") + def fromCaseClassString(string: String): DataType = CaseClassStringParser(string) + + private object CaseClassStringParser extends RegexParsers { + protected lazy val primitiveType: Parser[DataType] = + ( "StringType" ^^^ StringType + | "FloatType" ^^^ FloatType + | "IntegerType" ^^^ IntegerType + | "ByteType" ^^^ ByteType + | "ShortType" ^^^ ShortType + | "DoubleType" ^^^ DoubleType + | "LongType" ^^^ LongType + | "BinaryType" ^^^ BinaryType + | "BooleanType" ^^^ BooleanType + | "DateType" ^^^ DateType + | "DecimalType()" ^^^ DecimalType.Unlimited + | fixedDecimalType + | "TimestampType" ^^^ TimestampType + ) + + protected lazy val fixedDecimalType: Parser[DataType] = + ("DecimalType(" ~> "[0-9]+".r) ~ ("," ~> "[0-9]+".r <~ ")") ^^ { + case precision ~ scale => DecimalType(precision.toInt, scale.toInt) + } + + protected lazy val arrayType: Parser[DataType] = + "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ { + case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull) + } + + protected lazy val mapType: Parser[DataType] = + "MapType" ~> "(" ~> dataType ~ "," ~ dataType ~ "," ~ boolVal <~ ")" ^^ { + case t1 ~ _ ~ t2 ~ _ ~ valueContainsNull => MapType(t1, t2, valueContainsNull) + } + + protected lazy val structField: Parser[StructField] = + ("StructField(" ~> "[a-zA-Z0-9_]*".r) ~ ("," ~> dataType) ~ ("," ~> boolVal <~ ")") ^^ { + case name ~ tpe ~ nullable => + StructField(name, tpe, nullable = nullable) + } + + protected lazy val boolVal: Parser[Boolean] = + ( "true" ^^^ true + | "false" ^^^ false + ) + + protected lazy val structType: Parser[DataType] = + "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ { + case fields => StructType(fields) + } + + protected lazy val dataType: Parser[DataType] = + ( arrayType + | mapType + | structType + | primitiveType + ) + + /** + * Parses a string representation of a DataType. + * + * TODO: Generate parser as pickler... + */ + def apply(asString: String): DataType = parseAll(dataType, asString) match { + case Success(result, _) => result + case failure: NoSuccess => + throw new IllegalArgumentException(s"Unsupported dataType: $asString, $failure") + } + + } + + protected[types] def buildFormattedString( + dataType: DataType, + prefix: String, + builder: StringBuilder): Unit = { + dataType match { + case array: ArrayType => + array.buildFormattedString(prefix, builder) + case struct: StructType => + struct.buildFormattedString(prefix, builder) + case map: MapType => + map.buildFormattedString(prefix, builder) + case _ => + } + } + + /** + * Compares two types, ignoring nullability of ArrayType, MapType, StructType. 
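+   * For example, ArrayType(IntegerType, containsNull = true) and
+   * ArrayType(IntegerType, containsNull = false) are considered equal here.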
+ */ + private[sql] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = { + (left, right) match { + case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) => + equalsIgnoreNullability(leftElementType, rightElementType) + case (MapType(leftKeyType, leftValueType, _), MapType(rightKeyType, rightValueType, _)) => + equalsIgnoreNullability(leftKeyType, rightKeyType) && + equalsIgnoreNullability(leftValueType, rightValueType) + case (StructType(leftFields), StructType(rightFields)) => + leftFields.size == rightFields.size && + leftFields.zip(rightFields) + .forall{ + case (left, right) => + left.name == right.name && equalsIgnoreNullability(left.dataType, right.dataType) + } + case (left, right) => left == right + } + } +} + + +/** + * :: DeveloperApi :: + * + * The base type of all Spark SQL data types. + * + * @group dataType + */ +@DeveloperApi +abstract class DataType { + /** Matches any expression that evaluates to this DataType */ + def unapply(a: Expression): Boolean = a match { + case e: Expression if e.dataType == this => true + case _ => false + } + + def isPrimitive: Boolean = false + + def typeName: String = this.getClass.getSimpleName.stripSuffix("$").dropRight(4).toLowerCase + + private[sql] def jsonValue: JValue = typeName + + def json: String = compact(render(jsonValue)) + + def prettyJson: String = pretty(render(jsonValue)) +} + + +/** + * :: DeveloperApi :: + * + * The data type representing `NULL` values. Please use the singleton [[DataTypes.NullType]]. + * + * @group dataType + */ +@DeveloperApi +case object NullType extends DataType + + +object NativeType { + val all = Seq( + IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) + + def unapply(dt: DataType): Boolean = all.contains(dt) + + val defaultSizeOf: Map[NativeType, Int] = Map( + IntegerType -> 4, + BooleanType -> 1, + LongType -> 8, + DoubleType -> 8, + FloatType -> 4, + ShortType -> 2, + ByteType -> 1, + StringType -> 4096) +} + + +trait PrimitiveType extends DataType { + override def isPrimitive = true +} + + +object PrimitiveType { + private val nonDecimals = Seq(NullType, DateType, TimestampType, BinaryType) ++ NativeType.all + private val nonDecimalNameToType = nonDecimals.map(t => t.typeName -> t).toMap + + /** Given the string representation of a type, return its DataType */ + private[sql] def nameToType(name: String): DataType = { + val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\d+)\s*\)""".r + name match { + case "decimal" => DecimalType.Unlimited + case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt) + case other => nonDecimalNameToType(other) + } + } +} + +abstract class NativeType extends DataType { + private[sql] type JvmType + @transient private[sql] val tag: TypeTag[JvmType] + private[sql] val ordering: Ordering[JvmType] + + @transient private[sql] val classTag = ScalaReflectionLock.synchronized { + val mirror = runtimeMirror(Utils.getSparkClassLoader) + ClassTag[JvmType](mirror.runtimeClass(tag.tpe)) + } +} + + +/** + * :: DeveloperApi :: + * + * The data type representing `String` values. Please use the singleton [[DataTypes.StringType]]. 
+ *
+ * @group dataType
+ */
+@DeveloperApi
+case object StringType extends NativeType with PrimitiveType {
+  private[sql] type JvmType = String
+  @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
+  private[sql] val ordering = implicitly[Ordering[JvmType]]
+}
+
+
+/**
+ * :: DeveloperApi ::
+ *
+ * The data type representing `Array[Byte]` values.
+ * Please use the singleton [[DataTypes.BinaryType]].
+ *
+ * @group dataType
+ */
+@DeveloperApi
+case object BinaryType extends NativeType with PrimitiveType {
+  private[sql] type JvmType = Array[Byte]
+  @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
+  private[sql] val ordering = new Ordering[JvmType] {
+    def compare(x: Array[Byte], y: Array[Byte]): Int = {
+      for (i <- 0 until x.length; if i < y.length) {
+        val res = x(i).compareTo(y(i))
+        if (res != 0) return res
+      }
+      x.length - y.length
+    }
+  }
+}
+
+
+/**
+ * :: DeveloperApi ::
+ *
+ * The data type representing `Boolean` values. Please use the singleton [[DataTypes.BooleanType]].
+ *
+ * @group dataType
+ */
+@DeveloperApi
+case object BooleanType extends NativeType with PrimitiveType {
+  private[sql] type JvmType = Boolean
+  @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
+  private[sql] val ordering = implicitly[Ordering[JvmType]]
+}
+
+
+/**
+ * :: DeveloperApi ::
+ *
+ * The data type representing `java.sql.Timestamp` values.
+ * Please use the singleton [[DataTypes.TimestampType]].
+ *
+ * @group dataType
+ */
+@DeveloperApi
+case object TimestampType extends NativeType {
+  private[sql] type JvmType = Timestamp
+
+  @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
+
+  private[sql] val ordering = new Ordering[JvmType] {
+    def compare(x: Timestamp, y: Timestamp) = x.compareTo(y)
+  }
+}
+
+
+/**
+ * :: DeveloperApi ::
+ *
+ * The data type representing `java.sql.Date` values.
+ * Please use the singleton [[DataTypes.DateType]].
+ *
+ * @group dataType
+ */
+@DeveloperApi
+case object DateType extends NativeType {
+  private[sql] type JvmType = Date
+
+  @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
+
+  private[sql] val ordering = new Ordering[JvmType] {
+    def compare(x: Date, y: Date) = x.compareTo(y)
+  }
+}
+
+
+abstract class NumericType extends NativeType with PrimitiveType {
+  // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for
+  // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a
+  // type parameter and add a numeric annotation (i.e., [JvmType : Numeric]). This gets
+  // desugared by the compiler into an argument to the object's constructor. This means there is
+  // no longer a no-argument constructor and thus the JVM cannot serialize the object anymore.
+  private[sql] val numeric: Numeric[JvmType]
+}
+
+
+object NumericType {
+  def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType]
+}
+
+
+/** Matcher for any expressions that evaluate to [[IntegralType]]s */
+object IntegralType {
+  def unapply(a: Expression): Boolean = a match {
+    case e: Expression if e.dataType.isInstanceOf[IntegralType] => true
+    case _ => false
+  }
+}
+
+
+sealed abstract class IntegralType extends NumericType {
+  private[sql] val integral: Integral[JvmType]
+}
+
+
+/**
+ * :: DeveloperApi ::
+ *
+ * The data type representing `Long` values. Please use the singleton [[DataTypes.LongType]].
+ *
+ * @group dataType
+ */
+@DeveloperApi
+case object LongType extends IntegralType {
+  private[sql] type JvmType = Long
+  @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
+  private[sql] val numeric = implicitly[Numeric[Long]]
+  private[sql] val integral = implicitly[Integral[Long]]
+  private[sql] val ordering = implicitly[Ordering[JvmType]]
+}
+
+
+/**
+ * :: DeveloperApi ::
+ *
+ * The data type representing `Int` values. Please use the singleton [[DataTypes.IntegerType]].
+ *
+ * @group dataType
+ */
+@DeveloperApi
+case object IntegerType extends IntegralType {
+  private[sql] type JvmType = Int
+  @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
+  private[sql] val numeric = implicitly[Numeric[Int]]
+  private[sql] val integral = implicitly[Integral[Int]]
+  private[sql] val ordering = implicitly[Ordering[JvmType]]
+}
+
+
+/**
+ * :: DeveloperApi ::
+ *
+ * The data type representing `Short` values. Please use the singleton [[DataTypes.ShortType]].
+ *
+ * @group dataType
+ */
+@DeveloperApi
+case object ShortType extends IntegralType {
+  private[sql] type JvmType = Short
+  @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
+  private[sql] val numeric = implicitly[Numeric[Short]]
+  private[sql] val integral = implicitly[Integral[Short]]
+  private[sql] val ordering = implicitly[Ordering[JvmType]]
+}
+
+
+/**
+ * :: DeveloperApi ::
+ *
+ * The data type representing `Byte` values. Please use the singleton [[DataTypes.ByteType]].
+ *
+ * @group dataType
+ */
+@DeveloperApi
+case object ByteType extends IntegralType {
+  private[sql] type JvmType = Byte
+  @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
+  private[sql] val numeric = implicitly[Numeric[Byte]]
+  private[sql] val integral = implicitly[Integral[Byte]]
+  private[sql] val ordering = implicitly[Ordering[JvmType]]
+}
+
+
+/** Matcher for any expressions that evaluate to [[FractionalType]]s */
+object FractionalType {
+  def unapply(a: Expression): Boolean = a match {
+    case e: Expression if e.dataType.isInstanceOf[FractionalType] => true
+    case _ => false
+  }
+}
+
+
+sealed abstract class FractionalType extends NumericType {
+  private[sql] val fractional: Fractional[JvmType]
+  private[sql] val asIntegral: Integral[JvmType]
+}
+
+
+/** Precision parameters for a Decimal */
+case class PrecisionInfo(precision: Int, scale: Int)
+
+
+/**
+ * :: DeveloperApi ::
+ *
+ * The data type representing `scala.math.BigDecimal` values.
+ * A Decimal may have a fixed precision and scale, or leave them unlimited.
+ *
+ * Please use [[DataTypes.createDecimalType()]] to create a specific instance.
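+ * For example, DecimalType(10, 2) admits values such as 12345678.90, while DecimalType.Unlimited
+ * leaves precision and scale unbounded.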
+ + +/** + * :: DeveloperApi :: + * + * The data type representing `scala.math.BigDecimal` values. + * A Decimal may have a fixed precision and scale, or leave both unlimited. + * + * Please use [[DataTypes.createDecimalType()]] to create a specific instance. + * + * @group dataType + */ +@DeveloperApi +case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalType { + private[sql] type JvmType = Decimal + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] val numeric = Decimal.DecimalIsFractional + private[sql] val fractional = Decimal.DecimalIsFractional + private[sql] val ordering = Decimal.DecimalIsFractional + private[sql] val asIntegral = Decimal.DecimalAsIfIntegral + + def precision: Int = precisionInfo.map(_.precision).getOrElse(-1) + + def scale: Int = precisionInfo.map(_.scale).getOrElse(-1) + + override def typeName: String = precisionInfo match { + case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)" + case None => "decimal" + } + + override def toString: String = precisionInfo match { + case Some(PrecisionInfo(precision, scale)) => s"DecimalType($precision,$scale)" + case None => "DecimalType()" + } +} + + +/** Extra factory methods and pattern matchers for Decimals */ +object DecimalType { + val Unlimited: DecimalType = DecimalType(None) + + object Fixed { + def unapply(t: DecimalType): Option[(Int, Int)] = + t.precisionInfo.map(p => (p.precision, p.scale)) + } + + object Expression { + def unapply(e: Expression): Option[(Int, Int)] = e.dataType match { + case t: DecimalType => t.precisionInfo.map(p => (p.precision, p.scale)) + case _ => None + } + } + + def apply(): DecimalType = Unlimited + + def apply(precision: Int, scale: Int): DecimalType = + DecimalType(Some(PrecisionInfo(precision, scale))) + + def unapply(t: DataType): Boolean = t.isInstanceOf[DecimalType] + + def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[DecimalType] + + def isFixed(dataType: DataType): Boolean = dataType match { + case DecimalType.Fixed(_, _) => true + case _ => false + } +} + + +/** + * :: DeveloperApi :: + * + * The data type representing `Double` values. Please use the singleton [[DataTypes.DoubleType]]. + * + * @group dataType + */ +@DeveloperApi +case object DoubleType extends FractionalType { + private[sql] type JvmType = Double + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] val numeric = implicitly[Numeric[Double]] + private[sql] val fractional = implicitly[Fractional[Double]] + private[sql] val ordering = implicitly[Ordering[JvmType]] + private[sql] val asIntegral = DoubleAsIfIntegral +} + + +/** + * :: DeveloperApi :: + * + * The data type representing `Float` values. Please use the singleton [[DataTypes.FloatType]]. + * + * @group dataType + */ +@DeveloperApi +case object FloatType extends FractionalType { + private[sql] type JvmType = Float + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] val numeric = implicitly[Numeric[Float]] + private[sql] val fractional = implicitly[Fractional[Float]] + private[sql] val ordering = implicitly[Ordering[JvmType]] + private[sql] val asIntegral = FloatAsIfIntegral +} + + +object ArrayType { + /** Construct an [[ArrayType]] object with the given element type. `containsNull` defaults to `true`. */ + def apply(elementType: DataType): ArrayType = ArrayType(elementType, true) +}
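A short sketch of the two construction paths for array types (the one-argument factory above defaults containsNull to true; the case class below takes it explicitly; package name per this patch's move to org.apache.spark.sql.types):

    import org.apache.spark.sql.types._

    val withNulls    = ArrayType(IntegerType)         // containsNull = true
    val withoutNulls = ArrayType(StringType, false)   // elements may not be null

    println(withoutNulls.elementType)    // StringType
    println(withoutNulls.containsNull)   // false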
+ + +/** + * :: DeveloperApi :: + * + * The data type for collections of multiple values. + * Internally these are represented as columns that contain a ``scala.collection.Seq``. + * + * Please use [[DataTypes.createArrayType()]] to create a specific instance. + * + * An [[ArrayType]] object comprises two fields, `elementType: [[DataType]]` and + * `containsNull: Boolean`. `elementType` specifies the type of the array's elements, and + * `containsNull` indicates whether the array can contain `null` values. + * + * @param elementType The data type of the array's elements. + * @param containsNull Indicates whether the array can contain `null` values. + * + * @group dataType + */ +@DeveloperApi +case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType { + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { + builder.append( + s"$prefix-- element: ${elementType.typeName} (containsNull = $containsNull)\n") + DataType.buildFormattedString(elementType, s"$prefix |", builder) + } + + override private[sql] def jsonValue = + ("type" -> typeName) ~ + ("elementType" -> elementType.jsonValue) ~ + ("containsNull" -> containsNull) +} + + +/** + * A field inside a StructType. + * + * @param name The name of this field. + * @param dataType The data type of this field. + * @param nullable Indicates if values of this field can be `null` values. + * @param metadata The metadata of this field. The metadata should be preserved during + * transformation if the content of the column is not modified, e.g., in selection. + */ +case class StructField( + name: String, + dataType: DataType, + nullable: Boolean = true, + metadata: Metadata = Metadata.empty) { + + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { + builder.append(s"$prefix-- $name: ${dataType.typeName} (nullable = $nullable)\n") + DataType.buildFormattedString(dataType, s"$prefix |", builder) + } + + // override the default toString to be compatible with legacy parquet files. + override def toString: String = s"StructField($name,$dataType,$nullable)" + + private[sql] def jsonValue: JValue = { + ("name" -> name) ~ + ("type" -> dataType.jsonValue) ~ + ("nullable" -> nullable) ~ + ("metadata" -> metadata.jsonValue) + } +} + + +object StructType { + protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType = + StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))) + + def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray) + + def apply(fields: java.util.List[StructField]): StructType = { + StructType(fields.toArray.asInstanceOf[Array[StructField]]) + } +} + + +/** + * :: DeveloperApi :: + * + * A [[StructType]] object can be constructed by + * {{{ + * StructType(fields: Seq[StructField]) + * }}} + * For a [[StructType]] object, one or multiple [[StructField]]s can be extracted by names. + * If multiple [[StructField]]s are extracted, a [[StructType]] object will be returned. + * If a provided name does not have a matching field, an `IllegalArgumentException` will be + * thrown. + * Example: + * {{{ + * import org.apache.spark.sql._ + * + * val struct = + * StructType( + * StructField("a", IntegerType, true) :: + * StructField("b", LongType, false) :: + * StructField("c", BooleanType, false) :: Nil) + * + * // Extract a single StructField. + * val singleField = struct("b") + * // singleField: StructField = StructField(b,LongType,false) + * + * // This struct does not have a field called "d". Extracting it throws an exception. + * // struct("d") + * // java.lang.IllegalArgumentException: Field d does not exist. + * + * // Extract multiple StructFields. Field names are provided in a set. + * // A StructType object will be returned.
+ * val twoFields = struct(Set("b", "c")) + * // twoFields: StructType = + * // StructType(List(StructField(b,LongType,false), StructField(c,BooleanType,false))) + * + * // Any provided name without a matching field also causes an exception. + * // For the case shown below, "d" has no matching field. + * // struct(Set("b", "c", "d")) + * // java.lang.IllegalArgumentException: Field d does not exist. + * }}} + * + * A [[org.apache.spark.sql.catalyst.expressions.Row]] object is used as a value of the StructType. + * Example: + * {{{ + * import org.apache.spark.sql._ + * + * val innerStruct = + * StructType( + * StructField("f1", IntegerType, true) :: + * StructField("f2", LongType, false) :: + * StructField("f3", BooleanType, false) :: Nil) + * + * val struct = StructType( + * StructField("a", innerStruct, true) :: Nil) + * + * // Create a Row with the schema defined by struct + * val row = Row(Row(1, 2, true)) + * // row: Row = [[1,2,true]] + * }}} + * + * @group dataType + */ +@DeveloperApi +case class StructType(fields: Array[StructField]) extends DataType with Seq[StructField] { + + /** Returns all field names in an array. */ + def fieldNames: Array[String] = fields.map(_.name) + + private lazy val fieldNamesSet: Set[String] = fieldNames.toSet + private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap + + /** + * Extracts a [[StructField]] of the given name. If the [[StructType]] object does not + * have a field matching the given name, an `IllegalArgumentException` will be thrown. + */ + def apply(name: String): StructField = { + nameToField.getOrElse(name, throw new IllegalArgumentException(s"Field $name does not exist.")) + } + + /** + * Returns a [[StructType]] containing [[StructField]]s of the given names, preserving the + * original order of fields. An `IllegalArgumentException` will be thrown if any of the given + * names does not have a matching field. + */ + def apply(names: Set[String]): StructType = { + val nonExistFields = names -- fieldNamesSet + if (nonExistFields.nonEmpty) { + throw new IllegalArgumentException( + s"Field ${nonExistFields.mkString(",")} does not exist.") + } + // Preserve the original order of fields. + StructType(fields.filter(f => names.contains(f.name))) + } + + protected[sql] def toAttributes: Seq[AttributeReference] = + map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()) + + def treeString: String = { + val builder = new StringBuilder + builder.append("root\n") + val prefix = " |" + fields.foreach(field => field.buildFormattedString(prefix, builder)) + + builder.toString() + } + + def printTreeString(): Unit = println(treeString) + + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { + fields.foreach(field => field.buildFormattedString(prefix, builder)) + } + + override private[sql] def jsonValue = + ("type" -> typeName) ~ + ("fields" -> map(_.jsonValue)) + + override def apply(fieldIndex: Int): StructField = fields(fieldIndex) + + override def length: Int = fields.length + + override def iterator: Iterator[StructField] = fields.iterator +} + + +object MapType { + /** + * Construct a [[MapType]] object with the given key type and value type. + * `valueContainsNull` defaults to `true`. + */ + def apply(keyType: DataType, valueType: DataType): MapType = + MapType(keyType, valueType, valueContainsNull = true) +}
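Similarly for maps, a minimal usage sketch of the companion factory above versus the full constructor below (package name per this patch):

    import org.apache.spark.sql.types._

    val m1 = MapType(StringType, DoubleType)   // valueContainsNull = true
    val m2 = MapType(StringType, DoubleType, valueContainsNull = false)

    println(m1.valueContainsNull)   // true
    println(m2.valueType)           // DoubleType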
+ + +/** + * :: DeveloperApi :: + * + * The data type for Maps. Keys in a map are not allowed to be `null`. + * + * Please use [[DataTypes.createMapType()]] to create a specific instance. + * + * @param keyType The data type of map keys. + * @param valueType The data type of map values. + * @param valueContainsNull Indicates whether map values can be `null`. + * + * @group dataType + */ +@DeveloperApi +case class MapType( + keyType: DataType, + valueType: DataType, + valueContainsNull: Boolean) extends DataType { + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { + builder.append(s"$prefix-- key: ${keyType.typeName}\n") + builder.append(s"$prefix-- value: ${valueType.typeName} " + + s"(valueContainsNull = $valueContainsNull)\n") + DataType.buildFormattedString(keyType, s"$prefix |", builder) + DataType.buildFormattedString(valueType, s"$prefix |", builder) + } + + override private[sql] def jsonValue: JValue = + ("type" -> typeName) ~ + ("keyType" -> keyType.jsonValue) ~ + ("valueType" -> valueType.jsonValue) ~ + ("valueContainsNull" -> valueContainsNull) +} + + +/** + * :: DeveloperApi :: + * The data type for User Defined Types (UDTs). + * + * This interface allows a user to make their own classes more interoperable with Spark SQL; + * e.g., by creating a [[UserDefinedType]] for a class X, it becomes possible to create + * a SchemaRDD which has class X in the schema. + * + * For Spark SQL to recognize UDTs, the UDT must be annotated with + * [[SQLUserDefinedType]]. + * + * The conversion via `serialize` occurs when instantiating a `SchemaRDD` from another RDD. + * The conversion via `deserialize` occurs when reading from a `SchemaRDD`. + */ +@DeveloperApi +abstract class UserDefinedType[UserType] extends DataType with Serializable { + + /** Underlying storage type for this UDT */ + def sqlType: DataType + + /** Paired Python UDT class, if it exists. */ + def pyUDT: String = null + + /** + * Convert the user type to a SQL datum + * + * TODO: Can we make this take obj: UserType? The issue is in ScalaReflection.convertToCatalyst, + * where we need to convert Any to UserType. + */ + def serialize(obj: Any): Any + + /** Convert a SQL datum to the user type */ + def deserialize(datum: Any): UserType + + override private[sql] def jsonValue: JValue = { + ("type" -> "udt") ~ + ("class" -> this.getClass.getName) ~ + ("pyClass" -> pyUDT) ~ + ("sqlType" -> sqlType.jsonValue) + } + + /** + * Class object for the UserType + */ + def userClass: java.lang.Class[UserType] +}
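To make the UDT contract concrete, here is a hypothetical point UDT (a sketch only: Point2D, its storage layout, and the class names are invented for illustration, and registration via the SQLUserDefinedType annotation is elided):

    import org.apache.spark.sql.types._

    // Hypothetical user class (not part of this patch).
    case class Point2D(x: Double, y: Double)

    class Point2DUDT extends UserDefinedType[Point2D] {
      // Store a point as a Seq of two doubles, matching ArrayType's internal representation.
      override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)

      override def serialize(obj: Any): Any = obj match {
        case Point2D(x, y) => Seq(x, y)
      }

      override def deserialize(datum: Any): Point2D = datum match {
        case s: Seq[_] => Point2D(s(0).asInstanceOf[Double], s(1).asInstanceOf[Double])
      }

      override def userClass: java.lang.Class[Point2D] = classOf[Point2D]
    }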
http://git-wip-us.apache.org/repos/asf/spark/blob/f9969098/sql/catalyst/src/main/scala/org/apache/spark/sql/types/decimal/Decimal.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/decimal/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/decimal/Decimal.scala new file mode 100644 index 0000000..c7864d1 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/decimal/Decimal.scala @@ -0,0 +1,335 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types.decimal + +import org.apache.spark.annotation.DeveloperApi + +/** + * A mutable implementation of BigDecimal that can hold a Long if values are small enough. + * + * The semantics of the fields are as follows: + * - _precision and _scale represent the SQL precision and scale we are looking for + * - If decimalVal is set, it represents the whole decimal value + * - Otherwise, the decimal value is longVal / (10 ** _scale) + */ +final class Decimal extends Ordered[Decimal] with Serializable { + import Decimal.{MAX_LONG_DIGITS, POW_10, ROUNDING_MODE, BIG_DEC_ZERO} + + private var decimalVal: BigDecimal = null + private var longVal: Long = 0L + private var _precision: Int = 1 + private var _scale: Int = 0 + + def precision: Int = _precision + def scale: Int = _scale + + /** + * Set this Decimal to the given Long. Will have precision 20 and scale 0. + */ + def set(longVal: Long): Decimal = { + if (longVal <= -POW_10(MAX_LONG_DIGITS) || longVal >= POW_10(MAX_LONG_DIGITS)) { + // We can't represent this compactly as a long without risking overflow + this.decimalVal = BigDecimal(longVal) + this.longVal = 0L + } else { + this.decimalVal = null + this.longVal = longVal + } + this._precision = 20 + this._scale = 0 + this + } + + /** + * Set this Decimal to the given Int. Will have precision 10 and scale 0. + */ + def set(intVal: Int): Decimal = { + this.decimalVal = null + this.longVal = intVal + this._precision = 10 + this._scale = 0 + this + } + + /** + * Set this Decimal to the given unscaled Long, with a given precision and scale. + */ + def set(unscaled: Long, precision: Int, scale: Int): Decimal = { + if (setOrNull(unscaled, precision, scale) == null) { + throw new IllegalArgumentException("Unscaled value too large for precision") + } + this + } + + /** + * Set this Decimal to the given unscaled Long, with a given precision and scale, + * and return it, or return null if it cannot be set due to overflow. + */ + def setOrNull(unscaled: Long, precision: Int, scale: Int): Decimal = { + if (unscaled <= -POW_10(MAX_LONG_DIGITS) || unscaled >= POW_10(MAX_LONG_DIGITS)) { + // We can't represent this compactly as a long without risking overflow + if (precision < 19) { + return null // Requested precision is too low to represent this value + } + this.decimalVal = BigDecimal(unscaled, scale) + this.longVal = 0L + } else { + val p = POW_10(math.min(precision, MAX_LONG_DIGITS)) + if (unscaled <= -p || unscaled >= p) { + return null // Requested precision is too low to represent this value + } + this.decimalVal = null + this.longVal = unscaled + } + this._precision = precision + this._scale = scale + this + } + + /** + * Set this Decimal to the given BigDecimal value, with a given precision and scale.
+ */ + def set(decimal: BigDecimal, precision: Int, scale: Int): Decimal = { + this.decimalVal = decimal.setScale(scale, ROUNDING_MODE) + require(decimalVal.precision <= precision, "Overflowed precision") + this.longVal = 0L + this._precision = precision + this._scale = scale + this + } + + /** + * Set this Decimal to the given BigDecimal value, inheriting its precision and scale. + */ + def set(decimal: BigDecimal): Decimal = { + this.decimalVal = decimal + this.longVal = 0L + this._precision = decimal.precision + this._scale = decimal.scale + this + } + + /** + * Set this Decimal to the given Decimal value. + */ + def set(decimal: Decimal): Decimal = { + this.decimalVal = decimal.decimalVal + this.longVal = decimal.longVal + this._precision = decimal._precision + this._scale = decimal._scale + this + } + + def toBigDecimal: BigDecimal = { + if (decimalVal.ne(null)) { + decimalVal + } else { + BigDecimal(longVal, _scale) + } + } + + def toUnscaledLong: Long = { + if (decimalVal.ne(null)) { + decimalVal.underlying().unscaledValue().longValue() + } else { + longVal + } + } + + override def toString: String = toBigDecimal.toString() + + @DeveloperApi + def toDebugString: String = { + if (decimalVal.ne(null)) { + s"Decimal(expanded,$decimalVal,$precision,$scale)" + } else { + s"Decimal(compact,$longVal,$precision,$scale)" + } + } + + def toDouble: Double = toBigDecimal.doubleValue() + + def toFloat: Float = toBigDecimal.floatValue() + + def toLong: Long = { + if (decimalVal.eq(null)) { + longVal / POW_10(_scale) + } else { + decimalVal.longValue() + } + } + + def toInt: Int = toLong.toInt + + def toShort: Short = toLong.toShort + + def toByte: Byte = toLong.toByte
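A quick sketch of the compact-versus-expanded behavior behind the conversions above (assuming the companion factories defined later in this file; expected output uses the toDebugString format as corrected here):

    import org.apache.spark.sql.types.decimal.Decimal

    val small = Decimal(123456789L, 9, 2)   // 18 digits or fewer: stored in a Long
    val big   = Decimal(BigDecimal("12345678901234567890.12"), 22, 2)  // needs a BigDecimal

    println(small.toDebugString)   // Decimal(compact,123456789,9,2)
    println(big.toDebugString)     // Decimal(expanded,12345678901234567890.12,22,2)
    println(small.toBigDecimal)    // 1234567.89
    println(small.toLong)          // 1234567 (truncates toward zero)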
+ + /** + * Update precision and scale while keeping our value the same, and return true if successful. + * + * @return true if successful, false if overflow would occur + */ + def changePrecision(precision: Int, scale: Int): Boolean = { + // First, update our longVal if we can, or transfer over to using a BigDecimal + if (decimalVal.eq(null)) { + if (scale < _scale) { + // Easier case: we just need to divide our scale down + val diff = _scale - scale + val droppedDigits = longVal % POW_10(diff) + longVal /= POW_10(diff) + if (math.abs(droppedDigits) * 2 >= POW_10(diff)) { + longVal += (if (longVal < 0) -1L else 1L) + } + } else if (scale > _scale) { + // We might be able to multiply longVal by a power of 10 and not overflow, but if not, + // switch to using a BigDecimal + val diff = scale - _scale + val p = POW_10(math.max(MAX_LONG_DIGITS - diff, 0)) + if (diff <= MAX_LONG_DIGITS && longVal > -p && longVal < p) { + // Multiplying longVal by POW_10(diff) will still keep it below MAX_LONG_DIGITS + longVal *= POW_10(diff) + } else { + // Give up on using Longs; switch to BigDecimal, which we'll modify below + decimalVal = BigDecimal(longVal, _scale) + } + } + // In both cases, we will check whether our precision is okay below + } + + if (decimalVal.ne(null)) { + // We get here if either we started with a BigDecimal, or we switched to one because we would + // have overflowed our Long; in either case we must rescale decimalVal to the new scale. + val newVal = decimalVal.setScale(scale, ROUNDING_MODE) + if (newVal.precision > precision) { + return false + } + decimalVal = newVal + } else { + // We're still using Longs, but we should check whether we match the new precision + val p = POW_10(math.min(precision, MAX_LONG_DIGITS)) + if (longVal <= -p || longVal >= p) { + // Note that we shouldn't have been able to fix this by switching to BigDecimal + return false + } + } + + _precision = precision + _scale = scale + true + } + + override def clone(): Decimal = new Decimal().set(this) + + override def compare(other: Decimal): Int = { + if (decimalVal.eq(null) && other.decimalVal.eq(null) && _scale == other._scale) { + if (longVal < other.longVal) -1 else if (longVal == other.longVal) 0 else 1 + } else { + toBigDecimal.compare(other.toBigDecimal) + } + } + + override def equals(other: Any) = other match { + case d: Decimal => + compare(d) == 0 + case _ => + false + } + + override def hashCode(): Int = toBigDecimal.hashCode() + + def isZero: Boolean = if (decimalVal.ne(null)) decimalVal == BIG_DEC_ZERO else longVal == 0 + + def + (that: Decimal): Decimal = Decimal(toBigDecimal + that.toBigDecimal) + + def - (that: Decimal): Decimal = Decimal(toBigDecimal - that.toBigDecimal) + + def * (that: Decimal): Decimal = Decimal(toBigDecimal * that.toBigDecimal) + + def / (that: Decimal): Decimal = + if (that.isZero) null else Decimal(toBigDecimal / that.toBigDecimal) + + def % (that: Decimal): Decimal = + if (that.isZero) null else Decimal(toBigDecimal % that.toBigDecimal) + + def remainder(that: Decimal): Decimal = this % that + + def unary_- : Decimal = { + if (decimalVal.ne(null)) { + Decimal(-decimalVal) + } else { + Decimal(-longVal, precision, scale) + } + } +} + +object Decimal { + private val ROUNDING_MODE = BigDecimal.RoundingMode.HALF_UP + + /** Maximum number of decimal digits a Long can represent */ + val MAX_LONG_DIGITS = 18 + + private val POW_10 = Array.tabulate[Long](MAX_LONG_DIGITS + 1)(i => math.pow(10, i).toLong) + + private val BIG_DEC_ZERO = BigDecimal(0) + + def apply(value: Double): Decimal = new Decimal().set(value) + + def apply(value: Long): Decimal = new Decimal().set(value) + + def apply(value: Int): Decimal = new Decimal().set(value) + + def apply(value: BigDecimal): Decimal = new Decimal().set(value) + + def apply(value: BigDecimal, precision: Int, scale: Int): Decimal = + new Decimal().set(value, precision, scale) + + def apply(unscaled: Long, precision: Int, scale: Int): Decimal = + new Decimal().set(unscaled, precision, scale) + + def apply(value: String): Decimal = new Decimal().set(BigDecimal(value)) + + // Evidence parameters for Decimal considered either as Fractional or Integral. We provide two + // parameters inheriting from a common trait since both traits define mkNumericOps. + // See scala.math's Numeric.scala for the analogous definitions for Scala's built-in types.
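As an illustration of the evidence parameters described in the comment above, a hypothetical generic helper that sums Decimals through the Numeric interface (sumAll is invented for this sketch; DecimalIsFractional is the object defined just below):

    import org.apache.spark.sql.types.decimal.Decimal

    // Any T with Numeric evidence can be summed; DecimalIsFractional supplies it for Decimal.
    def sumAll[T](xs: Seq[T])(implicit num: Numeric[T]): T = xs.reduce(num.plus)

    val total = sumAll(Seq(Decimal("1.10"), Decimal("2.20")))(Decimal.DecimalIsFractional)
    println(total)  // 3.30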
+ + /** Common methods for Decimal evidence parameters */ + trait DecimalIsConflicted extends Numeric[Decimal] { + override def plus(x: Decimal, y: Decimal): Decimal = x + y + override def times(x: Decimal, y: Decimal): Decimal = x * y + override def minus(x: Decimal, y: Decimal): Decimal = x - y + override def negate(x: Decimal): Decimal = -x + override def toDouble(x: Decimal): Double = x.toDouble + override def toFloat(x: Decimal): Float = x.toFloat + override def toInt(x: Decimal): Int = x.toInt + override def toLong(x: Decimal): Long = x.toLong + override def fromInt(x: Int): Decimal = new Decimal().set(x) + override def compare(x: Decimal, y: Decimal): Int = x.compare(y) + } + + /** A [[scala.math.Fractional]] evidence parameter for Decimals. */ + object DecimalIsFractional extends DecimalIsConflicted with Fractional[Decimal] { + override def div(x: Decimal, y: Decimal): Decimal = x / y + } + + /** A [[scala.math.Integral]] evidence parameter for Decimals. */ + object DecimalAsIfIntegral extends DecimalIsConflicted with Integral[Decimal] { + override def quot(x: Decimal, y: Decimal): Decimal = x / y + override def rem(x: Decimal, y: Decimal): Decimal = x % y + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/f9969098/sql/catalyst/src/main/scala/org/apache/spark/sql/types/package.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/package.scala new file mode 100644 index 0000000..346a51e --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/package.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +/** + * Contains a type system for attributes produced by relations, including complex types like + * structs, arrays and maps. 
+ */ +package object types http://git-wip-us.apache.org/repos/asf/spark/blob/f9969098/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 7be24be..117725d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -23,7 +23,7 @@ import java.sql.{Date, Timestamp} import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions.Row -import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.types._ case class PrimitiveData( intField: Int, --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org
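For downstream code, the visible effect of this commit is the package move only, as the ScalaReflectionSuite hunk above shows; a sketch of the import change (class names are unchanged):

    // Before this commit:
    //   import org.apache.spark.sql.catalyst.types._
    //   import org.apache.spark.sql.catalyst.types.decimal.Decimal

    // After this commit:
    import org.apache.spark.sql.types._
    import org.apache.spark.sql.types.decimal.Decimal

    val schema = StructType(
      StructField("id", LongType, nullable = false) ::
      StructField("amount", DecimalType(10, 2)) :: Nil)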