Repository: spark Updated Branches: refs/heads/master 1dfb0f7b2 -> 860a49ef2
[SPARK-7153] [SQL] support all integral type ordinals in GetArrayItem. First convert `ordinal` to `Number`, then convert it to an `Int`, so that Byte, Short, Int and Long ordinals are all accepted. This commit also moves `CreateArray` and `CreateStruct` from complexTypes.scala into the new complexTypeCreator.scala. Author: Wenchen Fan <cloud0...@outlook.com> Closes #5706 from cloud-fan/7153 and squashes the following commits: 915db79 [Wenchen Fan] fix 7153 Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/860a49ef Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/860a49ef Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/860a49ef Branch: refs/heads/master Commit: 860a49ef20cea5711a7f54de0053ea33647e56a7 Parents: 1dfb0f7 Author: Wenchen Fan <cloud0...@outlook.com> Authored: Mon Jun 22 17:37:35 2015 -0700 Committer: Michael Armbrust <mich...@databricks.com> Committed: Mon Jun 22 17:37:35 2015 -0700 ---------------------------------------------------------------------- .../sql/catalyst/expressions/ExtractValue.scala | 2 +- .../expressions/complexTypeCreator.scala | 80 +++++++++++++++++++ .../sql/catalyst/expressions/complexTypes.scala | 81 -------------------- .../catalyst/expressions/ComplexTypeSuite.scala | 20 +++++ 4 files changed, 101 insertions(+), 82 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/860a49ef/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala index 013027b..4d6c1c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala @@ -186,7 +186,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression) // TODO: consider using 
Array[_] for ArrayType child to avoid // boxing of primitives val baseValue = value.asInstanceOf[Seq[_]] - val index = ordinal.asInstanceOf[Int] + val index = ordinal.asInstanceOf[Number].intValue() if (index >= baseValue.size || index < 0) { null } else { http://git-wip-us.apache.org/repos/asf/spark/blob/860a49ef/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala new file mode 100644 index 0000000..e0bf07e --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.types._ + + +/** + * Returns an Array containing the evaluation of all children expressions. 
+ */ +case class CreateArray(children: Seq[Expression]) extends Expression { + + override def foldable: Boolean = children.forall(_.foldable) + + lazy val childTypes = children.map(_.dataType).distinct + + override lazy val resolved = + childrenResolved && childTypes.size <= 1 + + override def dataType: DataType = { + assert(resolved, s"Invalid dataType of mixed ArrayType ${childTypes.mkString(",")}") + ArrayType( + childTypes.headOption.getOrElse(NullType), + containsNull = children.exists(_.nullable)) + } + + override def nullable: Boolean = false + + override def eval(input: InternalRow): Any = { + children.map(_.eval(input)) + } + + override def toString: String = s"Array(${children.mkString(",")})" +} + +/** + * Returns a Row containing the evaluation of all children expressions. + * TODO: [[CreateStruct]] does not support codegen. + */ +case class CreateStruct(children: Seq[Expression]) extends Expression { + + override def foldable: Boolean = children.forall(_.foldable) + + override lazy val resolved: Boolean = childrenResolved + + override lazy val dataType: StructType = { + assert(resolved, + s"CreateStruct contains unresolvable children: ${children.filterNot(_.resolved)}.") + val fields = children.zipWithIndex.map { case (child, idx) => + child match { + case ne: NamedExpression => + StructField(ne.name, ne.dataType, ne.nullable, ne.metadata) + case _ => + StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty) + } + } + StructType(fields) + } + + override def nullable: Boolean = false + + override def eval(input: InternalRow): Any = { + InternalRow(children.map(_.eval(input)): _*) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/860a49ef/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala deleted file mode 100644 index 72fdceb..0000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.spark.sql.catalyst -import org.apache.spark.sql.types._ - - -/** - * Returns an Array containing the evaluation of all children expressions. 
- */ -case class CreateArray(children: Seq[Expression]) extends Expression { - - override def foldable: Boolean = children.forall(_.foldable) - - lazy val childTypes = children.map(_.dataType).distinct - - override lazy val resolved = - childrenResolved && childTypes.size <= 1 - - override def dataType: DataType = { - assert(resolved, s"Invalid dataType of mixed ArrayType ${childTypes.mkString(",")}") - ArrayType( - childTypes.headOption.getOrElse(NullType), - containsNull = children.exists(_.nullable)) - } - - override def nullable: Boolean = false - - override def eval(input: InternalRow): Any = { - children.map(_.eval(input)) - } - - override def toString: String = s"Array(${children.mkString(",")})" -} - -/** - * Returns a Row containing the evaluation of all children expressions. - * TODO: [[CreateStruct]] does not support codegen. - */ -case class CreateStruct(children: Seq[Expression]) extends Expression { - - override def foldable: Boolean = children.forall(_.foldable) - - override lazy val resolved: Boolean = childrenResolved - - override lazy val dataType: StructType = { - assert(resolved, - s"CreateStruct contains unresolvable children: ${children.filterNot(_.resolved)}.") - val fields = children.zipWithIndex.map { case (child, idx) => - child match { - case ne: NamedExpression => - StructField(ne.name, ne.dataType, ne.nullable, ne.metadata) - case _ => - StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty) - } - } - StructType(fields) - } - - override def nullable: Boolean = false - - override def eval(input: InternalRow): Any = { - InternalRow(children.map(_.eval(input)): _*) - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/860a49ef/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 2b0f461..b80911e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -26,6 +26,26 @@ import org.apache.spark.unsafe.types.UTF8String class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { + /** + * Runs through the testFunc for all integral data types. + * + * @param testFunc a test function that accepts a conversion function to convert an integer + * into another data type. + */ + private def testIntegralDataTypes(testFunc: (Int => Any) => Unit): Unit = { + testFunc(_.toByte) + testFunc(_.toShort) + testFunc(identity) + testFunc(_.toLong) + } + + test("GetArrayItem") { + testIntegralDataTypes { convert => + val array = Literal.create(Seq("a", "b"), ArrayType(StringType)) + checkEvaluation(GetArrayItem(array, Literal(convert(1))), "b") + } + } + test("CreateStruct") { val row = InternalRow(1, 2, 3) val c1 = 'a.int.at(0).as("a") --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org