Repository: flink Updated Branches: refs/heads/master 427de663c -> ba46ab6b6
[FLINK-3736] [tableAPI] Move toRexNode logic into each expression's implementation. This closes #1870 Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/ba46ab6b Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/ba46ab6b Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/ba46ab6b Branch: refs/heads/master Commit: ba46ab6b659ffca60ea4a7b69f637622b9eb000c Parents: 427de66 Author: Yijie Shen <henry.yijies...@gmail.com> Authored: Tue Apr 12 01:33:15 2016 +0800 Committer: Fabian Hueske <fhue...@apache.org> Committed: Tue Apr 12 18:49:22 2016 +0200 ---------------------------------------------------------------------- .../api/table/expressions/Expression.scala | 19 ++- .../api/table/expressions/aggregations.scala | 34 ++++ .../api/table/expressions/arithmetic.scala | 49 +++++- .../flink/api/table/expressions/call.scala | 33 ++++ .../flink/api/table/expressions/cast.scala | 8 + .../api/table/expressions/comparison.scala | 35 +++- .../api/table/expressions/fieldExpression.scala | 15 +- .../flink/api/table/expressions/literals.scala | 7 + .../flink/api/table/expressions/logic.scala | 16 ++ .../api/table/plan/RexNodeTranslator.scala | 162 +------------------ .../org/apache/flink/api/table/table.scala | 26 ++- .../scala/table/test/AggregationsITCase.scala | 2 +- .../table/test/utils/ExpressionEvaluator.scala | 4 +- 13 files changed, 221 insertions(+), 189 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/ba46ab6b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/Expression.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/Expression.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/Expression.scala index cd278d0..6960a9f 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/Expression.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/Expression.scala @@ -18,30 +18,39 @@ package org.apache.flink.api.table.expressions import java.util.concurrent.atomic.AtomicInteger -import scala.language.postfixOps + +import org.apache.calcite.rex.RexNode +import org.apache.calcite.tools.RelBuilder abstract class Expression extends TreeNode[Expression] { self: Product => def name: String = Expression.freshName("expression") + + /** + * Convert Expression to its counterpart in Calcite, i.e. RexNode + */ + def toRexNode(implicit relBuilder: RelBuilder): RexNode = + throw new UnsupportedOperationException( + s"${this.getClass.getName} cannot be transformed to RexNode" + ) } -abstract class BinaryExpression() extends Expression { self: Product => +abstract class BinaryExpression extends Expression { self: Product => def left: Expression def right: Expression def children = Seq(left, right) } -abstract class UnaryExpression() extends Expression { self: Product => +abstract class UnaryExpression extends Expression { self: Product => def child: Expression def children = Seq(child) } -abstract class LeafExpression() extends Expression { self: Product => +abstract class LeafExpression extends Expression { self: Product => val children = Nil } case class NopExpression() extends LeafExpression { override val name = Expression.freshName("nop") - } object Expression { http://git-wip-us.apache.org/repos/asf/flink/blob/ba46ab6b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/aggregations.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/aggregations.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/aggregations.scala index d9d5fa8..8cd9dc3 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/aggregations.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/aggregations.scala @@ -17,26 +17,60 @@ */ package org.apache.flink.api.table.expressions +import org.apache.calcite.rex.RexNode +import org.apache.calcite.sql.fun.SqlStdOperatorTable +import org.apache.calcite.tools.RelBuilder +import org.apache.calcite.tools.RelBuilder.AggCall + abstract sealed class Aggregation extends UnaryExpression { self: Product => + override def toString = s"Aggregate($child)" + + override def toRexNode(implicit relBuilder: RelBuilder): RexNode = + throw new UnsupportedOperationException("Aggregate cannot be transformed to RexNode") + + /** + * Convert Aggregate to its counterpart in Calcite, i.e. AggCall + */ + def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall } case class Sum(child: Expression) extends Aggregation { override def toString = s"($child).sum" + + override def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = { + relBuilder.aggregateCall(SqlStdOperatorTable.SUM, false, null, name, child.toRexNode) + } } case class Min(child: Expression) extends Aggregation { override def toString = s"($child).min" + + override def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = { + relBuilder.aggregateCall(SqlStdOperatorTable.MIN, false, null, name, child.toRexNode) + } } case class Max(child: Expression) extends Aggregation { override def toString = s"($child).max" + + override def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = { + relBuilder.aggregateCall(SqlStdOperatorTable.MAX, false, null, name, child.toRexNode) + } } case class Count(child: Expression) extends Aggregation { override def toString = s"($child).count" + + override def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = { + relBuilder.aggregateCall(SqlStdOperatorTable.COUNT, false, null, name, child.toRexNode) + } } case class Avg(child: Expression) extends Aggregation { override def toString = s"($child).avg" + + override def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = { + relBuilder.aggregateCall(SqlStdOperatorTable.AVG, false, null, name, child.toRexNode) + } } http://git-wip-us.apache.org/repos/asf/flink/blob/ba46ab6b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/arithmetic.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/arithmetic.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/arithmetic.scala index b0bfa86..ca67697 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/arithmetic.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/arithmetic.scala @@ -17,28 +17,75 @@ */ package org.apache.flink.api.table.expressions -abstract class BinaryArithmetic extends BinaryExpression { self: Product => } +import scala.collection.JavaConversions._ + +import org.apache.calcite.rex.RexNode +import org.apache.calcite.sql.`type`.SqlTypeName +import org.apache.calcite.sql.SqlOperator +import org.apache.calcite.sql.fun.SqlStdOperatorTable +import org.apache.calcite.tools.RelBuilder + +import org.apache.flink.api.common.typeinfo.BasicTypeInfo +import org.apache.flink.api.table.typeutils.TypeConverter + +abstract class BinaryArithmetic extends BinaryExpression { self: Product => + def sqlOperator: SqlOperator + + override def toRexNode(implicit relBuilder: RelBuilder): RexNode = { + relBuilder.call(sqlOperator, children.map(_.toRexNode)) + } +} case class Plus(left: Expression, right: Expression) extends BinaryArithmetic { override def toString = s"($left + $right)" + + val sqlOperator = SqlStdOperatorTable.PLUS + + override def toRexNode(implicit relBuilder: RelBuilder): RexNode = { + val l = left.toRexNode + val r = right.toRexNode + if(SqlTypeName.STRING_TYPES.contains(l.getType.getSqlTypeName)) { + val cast: RexNode = relBuilder.cast(r, + TypeConverter.typeInfoToSqlType(BasicTypeInfo.STRING_TYPE_INFO)) + relBuilder.call(SqlStdOperatorTable.PLUS, l, cast) + } else if(SqlTypeName.STRING_TYPES.contains(r.getType.getSqlTypeName)) { + val cast: RexNode = relBuilder.cast(l, + TypeConverter.typeInfoToSqlType(BasicTypeInfo.STRING_TYPE_INFO)) + relBuilder.call(SqlStdOperatorTable.PLUS, cast, r) + } else { + relBuilder.call(SqlStdOperatorTable.PLUS, l, r) + } + } } case class UnaryMinus(child: Expression) extends UnaryExpression { override def toString = s"-($child)" + + override def toRexNode(implicit relBuilder: RelBuilder): RexNode = { + relBuilder.call(SqlStdOperatorTable.UNARY_MINUS, child.toRexNode) + } } case class Minus(left: Expression, right: Expression) extends BinaryArithmetic { override def toString = s"($left - $right)" + + val sqlOperator = SqlStdOperatorTable.MINUS } case class Div(left: Expression, right: Expression) extends BinaryArithmetic { override def toString = s"($left / $right)" + + val sqlOperator = SqlStdOperatorTable.DIVIDE } case class Mul(left: Expression, right: Expression) extends BinaryArithmetic { override def toString = s"($left * $right)" + + val sqlOperator = SqlStdOperatorTable.MULTIPLY } case class Mod(left: Expression, right: Expression) extends BinaryArithmetic { override def toString = s"($left % $right)" + + val sqlOperator = SqlStdOperatorTable.MOD } http://git-wip-us.apache.org/repos/asf/flink/blob/ba46ab6b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/call.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/call.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/call.scala index 9f74414..280d213 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/call.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/call.scala @@ -17,6 +17,13 @@ */ package org.apache.flink.api.table.expressions +import scala.collection.JavaConversions._ + +import org.apache.calcite.rex.RexNode +import org.apache.calcite.sql.SqlOperator +import org.apache.calcite.sql.fun.SqlStdOperatorTable +import org.apache.calcite.tools.RelBuilder + /** * General expression for unresolved function calls. The function can be a built-in * scalar function or a user-defined scalar function. @@ -25,6 +32,12 @@ case class Call(functionName: String, args: Expression*) extends Expression { override def children: Seq[Expression] = args + override def toRexNode(implicit relBuilder: RelBuilder): RexNode = { + relBuilder.call( + BuiltInFunctionNames.toSqlOperator(functionName), + args.map(_.toRexNode)) + } + override def toString = s"\\$functionName(${args.mkString(", ")})" override def makeCopy(newArgs: Seq[AnyRef]): this.type = { @@ -54,6 +67,26 @@ object BuiltInFunctionNames { val POWER = "POWER" val LN = "LN" val ABS = "ABS" + + def toSqlOperator(name: String): SqlOperator = { + name match { + case BuiltInFunctionNames.SUBSTRING => SqlStdOperatorTable.SUBSTRING + case BuiltInFunctionNames.TRIM => SqlStdOperatorTable.TRIM + case BuiltInFunctionNames.CHAR_LENGTH => SqlStdOperatorTable.CHAR_LENGTH + case BuiltInFunctionNames.UPPER_CASE => SqlStdOperatorTable.UPPER + case BuiltInFunctionNames.LOWER_CASE => SqlStdOperatorTable.LOWER + case BuiltInFunctionNames.INIT_CAP => SqlStdOperatorTable.INITCAP + case BuiltInFunctionNames.LIKE => SqlStdOperatorTable.LIKE + case BuiltInFunctionNames.SIMILAR => SqlStdOperatorTable.SIMILAR_TO + case BuiltInFunctionNames.EXP => SqlStdOperatorTable.EXP + case BuiltInFunctionNames.LOG10 => SqlStdOperatorTable.LOG10 + case BuiltInFunctionNames.POWER => SqlStdOperatorTable.POWER + case BuiltInFunctionNames.LN => SqlStdOperatorTable.LN + case BuiltInFunctionNames.ABS => SqlStdOperatorTable.ABS + case BuiltInFunctionNames.MOD => SqlStdOperatorTable.MOD + case _ => ??? + } + } } /** http://git-wip-us.apache.org/repos/asf/flink/blob/ba46ab6b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/cast.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/cast.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/cast.scala index eb97d04..fdad1f6 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/cast.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/cast.scala @@ -17,12 +17,20 @@ */ package org.apache.flink.api.table.expressions +import org.apache.calcite.rex.RexNode +import org.apache.calcite.tools.RelBuilder + import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.table.typeutils.TypeConverter case class Cast(child: Expression, tpe: TypeInformation[_]) extends UnaryExpression { override def toString = s"$child.cast($tpe)" + override def toRexNode(implicit relBuilder: RelBuilder): RexNode = { + relBuilder.cast(child.toRexNode, TypeConverter.typeInfoToSqlType(tpe)) + } + override def makeCopy(anyRefs: Seq[AnyRef]): this.type = { val child: Expression = anyRefs.head.asInstanceOf[Expression] copy(child, tpe).asInstanceOf[this.type] http://git-wip-us.apache.org/repos/asf/flink/blob/ba46ab6b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/comparison.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/comparison.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/comparison.scala index d9e9198..124393c 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/comparison.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/comparison.scala @@ -17,36 +17,69 @@ */ package org.apache.flink.api.table.expressions -abstract class BinaryComparison extends BinaryExpression { self: Product => } +import scala.collection.JavaConversions._ + +import org.apache.calcite.rex.RexNode +import org.apache.calcite.sql.SqlOperator +import org.apache.calcite.sql.fun.SqlStdOperatorTable +import org.apache.calcite.tools.RelBuilder + +abstract class BinaryComparison extends BinaryExpression { self: Product => + def sqlOperator: SqlOperator + + override def toRexNode(implicit relBuilder: RelBuilder): RexNode = { + relBuilder.call(sqlOperator, children.map(_.toRexNode)) + } +} case class EqualTo(left: Expression, right: Expression) extends BinaryComparison { override def toString = s"$left === $right" + + val sqlOperator: SqlOperator = SqlStdOperatorTable.EQUALS } case class NotEqualTo(left: Expression, right: Expression) extends BinaryComparison { override def toString = s"$left !== $right" + + val sqlOperator: SqlOperator = SqlStdOperatorTable.NOT_EQUALS } case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison { override def toString = s"$left > $right" + + val sqlOperator: SqlOperator = SqlStdOperatorTable.GREATER_THAN } case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { override def toString = s"$left >= $right" + + val sqlOperator: SqlOperator = SqlStdOperatorTable.GREATER_THAN_OR_EQUAL } case class LessThan(left: Expression, right: Expression) extends BinaryComparison { override def toString = s"$left < $right" + + val sqlOperator: SqlOperator = SqlStdOperatorTable.LESS_THAN } case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { override def toString = s"$left <= $right" + + val sqlOperator: SqlOperator = SqlStdOperatorTable.LESS_THAN_OR_EQUAL } case class IsNull(child: Expression) extends UnaryExpression { override def toString = s"($child).isNull" + + override def toRexNode(implicit relBuilder: RelBuilder): RexNode = { + relBuilder.isNull(child.toRexNode) + } } case class IsNotNull(child: Expression) extends UnaryExpression { override def toString = s"($child).isNotNull" + + override def toRexNode(implicit relBuilder: RelBuilder): RexNode = { + relBuilder.isNotNull(child.toRexNode) + } } http://git-wip-us.apache.org/repos/asf/flink/blob/ba46ab6b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/fieldExpression.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/fieldExpression.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/fieldExpression.scala index f3cb77e..82f7653 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/fieldExpression.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/fieldExpression.scala @@ -17,19 +17,28 @@ */ package org.apache.flink.api.table.expressions +import org.apache.calcite.rex.RexNode +import org.apache.calcite.tools.RelBuilder + case class UnresolvedFieldReference(override val name: String) extends LeafExpression { override def toString = "\"" + name -} -case class ResolvedFieldReference( - override val name: String) extends LeafExpression { + override def toRexNode(implicit relBuilder: RelBuilder): RexNode = { + relBuilder.field(name) + } +} +case class ResolvedFieldReference(override val name: String) extends LeafExpression { override def toString = s"'$name" } case class Naming(child: Expression, override val name: String) extends UnaryExpression { override def toString = s"$child as '$name" + override def toRexNode(implicit relBuilder: RelBuilder): RexNode = { + relBuilder.alias(child.toRexNode, name) + } + override def makeCopy(anyRefs: Seq[AnyRef]): this.type = { val child: Expression = anyRefs.head.asInstanceOf[Expression] copy(child, name).asInstanceOf[this.type] http://git-wip-us.apache.org/repos/asf/flink/blob/ba46ab6b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/literals.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/literals.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/literals.scala index 85956a2..efaa96d 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/literals.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/literals.scala @@ -18,6 +18,9 @@ package org.apache.flink.api.table.expressions import java.util.Date + +import org.apache.calcite.rex.RexNode +import org.apache.calcite.tools.RelBuilder import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} import org.apache.flink.api.scala.table.ImplicitExpressionOperations @@ -41,4 +44,8 @@ case class Literal(value: Any, tpe: TypeInformation[_]) def typeInfo = tpe override def toString = s"$value" + + override def toRexNode(implicit relBuilder: RelBuilder): RexNode = { + relBuilder.literal(value) + } } http://git-wip-us.apache.org/repos/asf/flink/blob/ba46ab6b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/logic.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/logic.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/logic.scala index 3f9b5c2..99da371 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/logic.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/logic.scala @@ -17,6 +17,9 @@ */ package org.apache.flink.api.table.expressions +import org.apache.calcite.rex.RexNode +import org.apache.calcite.tools.RelBuilder + abstract class BinaryPredicate extends BinaryExpression { self: Product => } case class Not(child: Expression) extends UnaryExpression { @@ -24,17 +27,30 @@ case class Not(child: Expression) extends UnaryExpression { override val name = Expression.freshName("not-" + child.name) override def toString = s"!($child)" + + override def toRexNode(implicit relBuilder: RelBuilder): RexNode = { + relBuilder.not(child.toRexNode) + } } case class And(left: Expression, right: Expression) extends BinaryPredicate { + override def toString = s"$left && $right" override val name = Expression.freshName(left.name + "-and-" + right.name) + + override def toRexNode(implicit relBuilder: RelBuilder): RexNode = { + relBuilder.and(left.toRexNode, right.toRexNode) + } } case class Or(left: Expression, right: Expression) extends BinaryPredicate { + override def toString = s"$left || $right" override val name = Expression.freshName(left.name + "-or-" + right.name) + override def toRexNode(implicit relBuilder: RelBuilder): RexNode = { + relBuilder.or(left.toRexNode, right.toRexNode) + } } http://git-wip-us.apache.org/repos/asf/flink/blob/ba46ab6b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/RexNodeTranslator.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/RexNodeTranslator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/RexNodeTranslator.scala index b50b74b..926e023 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/RexNodeTranslator.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/RexNodeTranslator.scala @@ -18,17 +18,10 @@ package org.apache.flink.api.table.plan -import org.apache.calcite.rex.RexNode -import org.apache.calcite.sql.SqlOperator -import org.apache.calcite.sql.`type`.SqlTypeName -import org.apache.calcite.sql.fun.SqlStdOperatorTable import org.apache.calcite.tools.RelBuilder import org.apache.calcite.tools.RelBuilder.AggCall -import org.apache.flink.api.common.typeinfo.BasicTypeInfo -import org.apache.flink.api.table.expressions._ -import org.apache.flink.api.table.typeutils.TypeConverter -import scala.collection.JavaConversions._ +import org.apache.flink.api.table.expressions._ object RexNodeTranslator { @@ -42,11 +35,11 @@ object RexNodeTranslator { exp match { case agg: Aggregation => val name = TranslationContext.getUniqueName - val aggCall = toAggCall(agg, name, relBuilder) + val aggCall = agg.toAggCall(name)(relBuilder) val fieldExp = new UnresolvedFieldReference(name) (fieldExp, List(aggCall)) case n@Naming(agg: Aggregation, name) => - val aggCall = toAggCall(agg, name, relBuilder) + val aggCall = agg.toAggCall(name)(relBuilder) val fieldExp = new UnresolvedFieldReference(name) (fieldExp, List(aggCall)) case l: LeafExpression => @@ -69,153 +62,4 @@ object RexNodeTranslator { s"Expression $e of type ${e.getClass} not supported yet") } } - - /** - * Translates a Table API expression into a Calcite RexNode. - */ - def toRexNode(exp: Expression, relBuilder: RelBuilder): RexNode = { - - exp match { - // Basic operators - case Literal(value, tpe) => - relBuilder.literal(value) - case ResolvedFieldReference(name) => - relBuilder.field(name) - case UnresolvedFieldReference(name) => - relBuilder.field(name) - case NopExpression() => - throw new IllegalArgumentException("NoOp expression encountered") - case Naming(child, name) => - val c = toRexNode(child, relBuilder) - relBuilder.alias(c, name) - case Cast(child, tpe) => - val c = toRexNode(child, relBuilder) - relBuilder.cast(c, TypeConverter.typeInfoToSqlType(tpe)) - case Not(child) => - val c = toRexNode(child, relBuilder) - relBuilder.not(c) - case Or(left, right) => - val l = toRexNode(left, relBuilder) - val r = toRexNode(right, relBuilder) - relBuilder.or(l, r) - case And(left, right) => - val l = toRexNode(left, relBuilder) - val r = toRexNode(right, relBuilder) - relBuilder.and(l, r) - case EqualTo(left, right) => - val l = toRexNode(left, relBuilder) - val r = toRexNode(right, relBuilder) - relBuilder.equals(l, r) - case NotEqualTo(left, right) => - val l = toRexNode(left, relBuilder) - val r = toRexNode(right, relBuilder) - relBuilder.not(relBuilder.equals(l, r)) - relBuilder.call(SqlStdOperatorTable.NOT_EQUALS, l, r) - case LessThan(left, right) => - val l = toRexNode(left, relBuilder) - val r = toRexNode(right, relBuilder) - relBuilder.call(SqlStdOperatorTable.LESS_THAN, l, r) - case LessThanOrEqual(left, right) => - val l = toRexNode(left, relBuilder) - val r = toRexNode(right, relBuilder) - relBuilder.call(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, l, r) - case GreaterThan(left, right) => - val l = toRexNode(left, relBuilder) - val r = toRexNode(right, relBuilder) - relBuilder.call(SqlStdOperatorTable.GREATER_THAN, l, r) - case GreaterThanOrEqual(left, right) => - val l = toRexNode(left, relBuilder) - val r = toRexNode(right, relBuilder) - relBuilder.call(SqlStdOperatorTable.GREATER_THAN_OR_EQUAL, l, r) - case IsNull(child) => - val c = toRexNode(child, relBuilder) - relBuilder.isNull(c) - case IsNotNull(child) => - val c = toRexNode(child, relBuilder) - relBuilder.isNotNull(c) - case Plus(left, right) => - val l = toRexNode(left, relBuilder) - val r = toRexNode(right, relBuilder) - if(SqlTypeName.STRING_TYPES.contains(l.getType.getSqlTypeName)) { - val cast: RexNode = relBuilder.cast(r, - TypeConverter.typeInfoToSqlType(BasicTypeInfo.STRING_TYPE_INFO)) - relBuilder.call(SqlStdOperatorTable.PLUS, l, cast) - } else if(SqlTypeName.STRING_TYPES.contains(r.getType.getSqlTypeName)) { - val cast: RexNode = relBuilder.cast(l, - TypeConverter.typeInfoToSqlType(BasicTypeInfo.STRING_TYPE_INFO)) - relBuilder.call(SqlStdOperatorTable.PLUS, cast, r) - } else { - relBuilder.call(SqlStdOperatorTable.PLUS, l, r) - } - case Minus(left, right) => - val l = toRexNode(left, relBuilder) - val r = toRexNode(right, relBuilder) - relBuilder.call(SqlStdOperatorTable.MINUS, l, r) - case Mul(left, right) => - val l = toRexNode(left, relBuilder) - val r = toRexNode(right, relBuilder) - relBuilder.call(SqlStdOperatorTable.MULTIPLY, l, r) - case Div(left, right) => - val l = toRexNode(left, relBuilder) - val r = toRexNode(right, relBuilder) - relBuilder.call(SqlStdOperatorTable.DIVIDE, l, r) - case Mod(left, right) => - val l = toRexNode(left, relBuilder) - val r = toRexNode(right, relBuilder) - relBuilder.call(SqlStdOperatorTable.MOD, l, r) - case UnaryMinus(child) => - val c = toRexNode(child, relBuilder) - relBuilder.call(SqlStdOperatorTable.UNARY_MINUS, c) - - // Scalar functions - case Call(name, args@_*) => - val rexArgs = args.map(toRexNode(_, relBuilder)) - val sqlOperator = toSqlOperator(name) - relBuilder.call(sqlOperator, rexArgs) - - case a: Aggregation => - throw new IllegalArgumentException(s"Aggregation expression $a not allowed at this place") - case e@AnyRef => - throw new IllegalArgumentException( - s"Expression $e of type ${e.getClass} not supported yet") - } - } - - private def toAggCall(agg: Aggregation, name: String, relBuilder: RelBuilder): AggCall = { - - val rexNode = toRexNode(agg.child, relBuilder) - agg match { - case s: Sum => relBuilder.aggregateCall( - SqlStdOperatorTable.SUM, false, null, name, rexNode) - case m: Min => relBuilder.aggregateCall( - SqlStdOperatorTable.MIN, false, null, name, rexNode) - case m: Max => relBuilder.aggregateCall( - SqlStdOperatorTable.MAX, false, null, name, rexNode) - case c: Count => relBuilder.aggregateCall( - SqlStdOperatorTable.COUNT, false, null, name, rexNode) - case a: Avg => relBuilder.aggregateCall( - SqlStdOperatorTable.AVG, false, null, name, rexNode) - } - } - - private def toSqlOperator(name: String): SqlOperator = { - name match { - case BuiltInFunctionNames.SUBSTRING => SqlStdOperatorTable.SUBSTRING - case BuiltInFunctionNames.TRIM => SqlStdOperatorTable.TRIM - case BuiltInFunctionNames.CHAR_LENGTH => SqlStdOperatorTable.CHAR_LENGTH - case BuiltInFunctionNames.UPPER_CASE => SqlStdOperatorTable.UPPER - case BuiltInFunctionNames.LOWER_CASE => SqlStdOperatorTable.LOWER - case BuiltInFunctionNames.INIT_CAP => SqlStdOperatorTable.INITCAP - case BuiltInFunctionNames.LIKE => SqlStdOperatorTable.LIKE - case BuiltInFunctionNames.SIMILAR => SqlStdOperatorTable.SIMILAR_TO - case BuiltInFunctionNames.EXP => SqlStdOperatorTable.EXP - case BuiltInFunctionNames.LOG10 => SqlStdOperatorTable.LOG10 - case BuiltInFunctionNames.POWER => SqlStdOperatorTable.POWER - case BuiltInFunctionNames.LN => SqlStdOperatorTable.LN - case BuiltInFunctionNames.ABS => SqlStdOperatorTable.ABS - case BuiltInFunctionNames.MOD => SqlStdOperatorTable.MOD - case _ => ??? - } - } - } http://git-wip-us.apache.org/repos/asf/flink/blob/ba46ab6b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala index 53c3b4a..7b40c57 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala @@ -30,7 +30,7 @@ import org.apache.calcite.util.NlsString import org.apache.flink.api.java.io.DiscardingOutputFormat import org.apache.flink.api.table.explain.PlanJsonParser import org.apache.flink.api.table.plan.{PlanGenException, RexNodeTranslator} -import RexNodeTranslator.{toRexNode, extractAggCalls} +import RexNodeTranslator.extractAggCalls import org.apache.flink.api.table.expressions.{ExpressionParser, Naming, UnresolvedFieldReference, Expression} import org.apache.flink.api.scala._ @@ -96,8 +96,7 @@ class Table( .map(extractAggCalls(_, relBuilder)).toList // get aggregation calls - val aggCalls: List[AggCall] = extractedAggCalls - .map(_._2).reduce( (x,y) => x ::: y) + val aggCalls: List[AggCall] = extractedAggCalls.flatMap(_._2) // apply aggregations if (aggCalls.nonEmpty) { @@ -106,9 +105,7 @@ class Table( } // get selection expressions - val exprs: List[RexNode] = extractedAggCalls - .map(_._1) - .map(toRexNode(_, relBuilder)) + val exprs: List[RexNode] = extractedAggCalls.map(_._1.toRexNode(relBuilder)) relBuilder.project(exprs.toIterable.asJava) val projected = relBuilder.build() @@ -170,7 +167,7 @@ class Table( relBuilder.push(relNode) - val exprs = (renamings ++ remaining).map(toRexNode(_, relBuilder)) + val exprs = (renamings ++ remaining).map(_.toRexNode(relBuilder)) new Table(createRenamingProject(exprs), relBuilder) } @@ -203,8 +200,7 @@ class Table( def filter(predicate: Expression): Table = { relBuilder.push(relNode) - val pred = toRexNode(predicate, relBuilder) - relBuilder.filter(pred) + relBuilder.filter(predicate.toRexNode(relBuilder)) new Table(relBuilder.build(), relBuilder) } @@ -264,7 +260,7 @@ class Table( def groupBy(fields: Expression*): GroupedTable = { relBuilder.push(relNode) - val groupExpr = fields.map(toRexNode(_, relBuilder)).toIterable.asJava + val groupExpr = fields.map(_.toRexNode(relBuilder)).toIterable.asJava val groupKey = relBuilder.groupKey(groupExpr) new GroupedTable(relBuilder.build(), relBuilder, groupKey) @@ -450,19 +446,15 @@ class GroupedTable( .map(extractAggCalls(_, relBuilder)).toList // get aggregation calls - val aggCalls: List[AggCall] = extractedAggCalls - .map(_._2).reduce( (x,y) => x ::: y) + val aggCalls: List[AggCall] = extractedAggCalls.flatMap(_._2) // apply aggregations relBuilder.aggregate(groupKey, aggCalls.toIterable.asJava) // get selection expressions val exprs: List[RexNode] = try { - extractedAggCalls - .map(_._1) - .map(toRexNode(_, relBuilder)) - } - catch { + extractedAggCalls.map(_._1.toRexNode(relBuilder)) + } catch { case iae: IllegalArgumentException => throw new IllegalArgumentException( "Only grouping fields and aggregations allowed after groupBy.", iae) http://git-wip-us.apache.org/repos/asf/flink/blob/ba46ab6b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/AggregationsITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/AggregationsITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/AggregationsITCase.scala index 0741db8..abf2735 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/AggregationsITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/AggregationsITCase.scala @@ -133,7 +133,7 @@ class AggregationsITCase(mode: TestExecutionMode) extends MultipleProgramsTestBa t.collect() } - @Test(expected = classOf[IllegalArgumentException]) + @Test(expected = classOf[UnsupportedOperationException]) def testNoNestedAggregations(): Unit = { val env = ExecutionEnvironment.getExecutionEnvironment http://git-wip-us.apache.org/repos/asf/flink/blob/ba46ab6b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/test/utils/ExpressionEvaluator.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/test/utils/ExpressionEvaluator.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/test/utils/ExpressionEvaluator.scala index a52bbbd..48dea56 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/test/utils/ExpressionEvaluator.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/test/utils/ExpressionEvaluator.scala @@ -29,7 +29,7 @@ import org.apache.flink.api.java.DataSet import org.apache.flink.api.table.TableConfig import org.apache.flink.api.table.codegen.{CodeGenerator, GeneratedFunction} import org.apache.flink.api.table.expressions.Expression -import org.apache.flink.api.table.plan.{RexNodeTranslator, TranslationContext} +import org.apache.flink.api.table.plan.TranslationContext import org.apache.flink.api.table.plan.schema.DataSetTable import org.apache.flink.api.table.runtime.FunctionCompiler import org.mockito.Mockito._ @@ -78,7 +78,7 @@ object ExpressionEvaluator { def evaluate(data: Any, typeInfo: TypeInformation[Any], expr: Expression): String = { val relBuilder = prepareTable(typeInfo)._2 - evaluate(data, typeInfo, relBuilder, RexNodeTranslator.toRexNode(expr, relBuilder)) + evaluate(data, typeInfo, relBuilder, expr.toRexNode(relBuilder)) } def evaluate(