Repository: spark Updated Branches: refs/heads/master 2848f4da4 -> ab535b9a1
[SPARK-8226] [SQL] Add function shiftrightunsigned Author: zhichao.li <zhichao...@intel.com> Closes #7035 from zhichao-li/shiftRightUnsigned and squashes the following commits: 6bcca5a [zhichao.li] change coding style 3e9f5ae [zhichao.li] python style d85ae0b [zhichao.li] add shiftrightunsigned Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ab535b9a Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ab535b9a Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ab535b9a Branch: refs/heads/master Commit: ab535b9a1dab40ea7335ff9abb9b522fc2b5ed66 Parents: 2848f4d Author: zhichao.li <zhichao...@intel.com> Authored: Fri Jul 3 15:39:16 2015 -0700 Committer: Davies Liu <davies....@gmail.com> Committed: Fri Jul 3 15:39:16 2015 -0700 ---------------------------------------------------------------------- python/pyspark/sql/functions.py | 13 ++++++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../spark/sql/catalyst/expressions/math.scala | 49 ++++++++++++++++++++ .../expressions/MathFunctionsSuite.scala | 13 ++++++ .../scala/org/apache/spark/sql/functions.scala | 20 ++++++++ .../apache/spark/sql/MathExpressionsSuite.scala | 17 +++++++ 6 files changed, 113 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/ab535b9a/python/pyspark/sql/functions.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 12263e6..69e563e 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -436,6 +436,19 @@ def shiftRight(col, numBits): return Column(jc) +@since(1.5) +def shiftRightUnsigned(col, numBits): + """Unsigned shift the given value numBits right.
+ + >>> sqlContext.createDataFrame([(-42,)], ['a']).select(shiftRightUnsigned('a', 1).alias('r'))\ + .collect() + [Row(r=9223372036854775787)] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.shiftRightUnsigned(_to_java_column(col), numBits) + return Column(jc) + + @since(1.4) def sparkPartitionId(): """A column for partition ID of the Spark task. http://git-wip-us.apache.org/repos/asf/spark/blob/ab535b9a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 9163b03..cd5ba12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -129,6 +129,7 @@ object FunctionRegistry { expression[Rint]("rint"), expression[ShiftLeft]("shiftleft"), expression[ShiftRight]("shiftright"), + expression[ShiftRightUnsigned]("shiftrightunsigned"), expression[Signum]("sign"), expression[Signum]("signum"), expression[Sin]("sin"), http://git-wip-us.apache.org/repos/asf/spark/blob/ab535b9a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 273a6c5..0fc320f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -521,6 +521,55 @@ case class ShiftRight(left: Expression, right: Expression) extends BinaryExpress } } +case
class ShiftRightUnsigned(left: Expression, right: Expression) extends BinaryExpression { + + override def checkInputDataTypes(): TypeCheckResult = { + (left.dataType, right.dataType) match { + case (NullType, _) | (_, NullType) => return TypeCheckResult.TypeCheckSuccess + case (_, IntegerType) => left.dataType match { + case LongType | IntegerType | ShortType | ByteType => + return TypeCheckResult.TypeCheckSuccess + case _ => // failed + } + case _ => // failed + } + TypeCheckResult.TypeCheckFailure( + s"ShiftRightUnsigned expects long, integer, short or byte value as first argument and an " + + s"integer value as second argument, not (${left.dataType}, ${right.dataType})") + } + + override def eval(input: InternalRow): Any = { + val valueLeft = left.eval(input) + if (valueLeft != null) { + val valueRight = right.eval(input) + if (valueRight != null) { + valueLeft match { + case l: Long => l >>> valueRight.asInstanceOf[Integer] + case i: Integer => i >>> valueRight.asInstanceOf[Integer] + case s: Short => s >>> valueRight.asInstanceOf[Integer] + case b: Byte => b >>> valueRight.asInstanceOf[Integer] + } + } else { + null + } + } else { + null + } + } + + override def dataType: DataType = { + left.dataType match { + case LongType => LongType + case IntegerType | ShortType | ByteType => IntegerType + case _ => NullType + } + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left >>> $right;") + } +} + /** * Performs the inverse operation of HEX. * Resulting characters are returned as a byte array.
http://git-wip-us.apache.org/repos/asf/spark/blob/ab535b9a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 8457864..20839c8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -264,6 +264,19 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(ShiftRight(Literal(-42.toLong), Literal(1)), -21.toLong) } + test("shift right unsigned") { + checkEvaluation(ShiftRightUnsigned(Literal.create(null, IntegerType), Literal(1)), null) + checkEvaluation(ShiftRightUnsigned(Literal(42), Literal.create(null, IntegerType)), null) + checkEvaluation(ShiftRightUnsigned( + Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null) + checkEvaluation(ShiftRightUnsigned(Literal(42), Literal(1)), 21) + checkEvaluation(ShiftRightUnsigned(Literal(42.toByte), Literal(1)), 21) + checkEvaluation(ShiftRightUnsigned(Literal(42.toShort), Literal(1)), 21) + checkEvaluation(ShiftRightUnsigned(Literal(42.toLong), Literal(1)), 21.toLong) + + checkEvaluation(ShiftRightUnsigned(Literal(-42.toLong), Literal(1)), 9223372036854775787L) + } + test("hex") { checkEvaluation(Hex(Literal(28)), "1C") checkEvaluation(Hex(Literal(-28)), "FFFFFFFFFFFFFFE4") http://git-wip-us.apache.org/repos/asf/spark/blob/ab535b9a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index
0d5d49c..4b70dc5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1344,6 +1344,26 @@ object functions { def shiftRight(e: Column, numBits: Int): Column = ShiftRight(e.expr, lit(numBits).expr) /** + * Unsigned shift the given value numBits right. If the given value is a long value, + * it will return a long value else it will return an integer value. + * + * @group math_funcs + * @since 1.5.0 + */ + def shiftRightUnsigned(columnName: String, numBits: Int): Column = + shiftRightUnsigned(Column(columnName), numBits) + + /** + * Unsigned shift the given value numBits right. If the given value is a long value, + * it will return a long value else it will return an integer value. + * + * @group math_funcs + * @since 1.5.0 + */ + def shiftRightUnsigned(e: Column, numBits: Int): Column = + ShiftRightUnsigned(e.expr, lit(numBits).expr) + + /** * Shift the the given value numBits right. If the given value is a long value, it will return * a long value else it will return an integer value.
* http://git-wip-us.apache.org/repos/asf/spark/blob/ab535b9a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index dc8f994..24bef21 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -304,6 +304,23 @@ class MathExpressionsSuite extends QueryTest { Row(21.toLong, 21, 21.toShort, 21.toByte, null)) } + test("shift right unsigned") { + val df = Seq[(Long, Integer, Short, Byte, Integer, Integer)]((-42, 42, 42, 42, 42, null)) + .toDF("a", "b", "c", "d", "e", "f") + + checkAnswer( + df.select( + shiftRightUnsigned('a, 1), shiftRightUnsigned('b, 1), shiftRightUnsigned('c, 1), + shiftRightUnsigned('d, 1), shiftRightUnsigned('f, 1)), + Row(9223372036854775787L, 21, 21, 21, null)) + + checkAnswer( + df.selectExpr( + "shiftRightUnsigned(a, 1)", "shiftRightUnsigned(b, 1)", "shiftRightUnsigned(c, 1)", + "shiftRightUnsigned(d, 1)", "shiftRightUnsigned(f, 1)"), + Row(9223372036854775787L, 21, 21, 21, null)) + } + test("binary log") { val df = Seq[(Integer, Integer)]((123, null)).toDF("a", "b") checkAnswer( --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org