http://git-wip-us.apache.org/repos/asf/flink/blob/495f104b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/fieldExpression.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/fieldExpression.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/fieldExpression.scala index 0885929..362d846 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/fieldExpression.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/fieldExpression.scala @@ -19,8 +19,9 @@ package org.apache.flink.table.expressions import org.apache.calcite.rex.RexNode import org.apache.calcite.tools.RelBuilder -import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} +import org.apache.flink.api.common.typeinfo.{SqlTimeTypeInfo, TypeInformation} import org.apache.flink.table.api.{UnresolvedException, ValidationException} +import org.apache.flink.table.typeutils.TimeIndicatorTypeInfo import org.apache.flink.table.validate.{ValidationFailure, ValidationResult, ValidationSuccess} trait NamedExpression extends Expression { @@ -116,24 +117,6 @@ case class UnresolvedAlias(child: Expression) extends UnaryExpression with Named override private[flink] lazy val valid = false } -case class RowtimeAttribute() extends Attribute { - override private[flink] def withName(newName: String): Attribute = { - if (newName == "rowtime") { - this - } else { - throw new ValidationException("Cannot rename streaming rowtime attribute.") - } - } - - override private[flink] def name: String = "rowtime" - - override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = { - throw new UnsupportedOperationException("A rowtime attribute can not be used solely.") - } - - override private[flink] def resultType: TypeInformation[_] = BasicTypeInfo.LONG_TYPE_INFO -} - case class WindowReference(name: String) extends Attribute { override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = @@ -150,3 +133,30 @@ case class WindowReference(name: String) extends Attribute { } } } + +abstract class TimeAttribute(val expression: Expression) + extends UnaryExpression + with NamedExpression { + + override private[flink] def child: Expression = expression + + override private[flink] def name: String = expression match { + case UnresolvedFieldReference(name) => name + case _ => throw new ValidationException("Unresolved field reference expected.") + } + + override private[flink] def toAttribute: Attribute = + throw new UnsupportedOperationException("Time attribute can not be used solely.") +} + +case class RowtimeAttribute(expr: Expression) extends TimeAttribute(expr) { + + override private[flink] def resultType: TypeInformation[_] = + TimeIndicatorTypeInfo.ROWTIME_INDICATOR +} + +case class ProctimeAttribute(expr: Expression) extends TimeAttribute(expr) { + + override private[flink] def resultType: TypeInformation[_] = + TimeIndicatorTypeInfo.PROCTIME_INDICATOR +}
http://git-wip-us.apache.org/repos/asf/flink/blob/495f104b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/TimeMaterializationSqlFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/TimeMaterializationSqlFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/TimeMaterializationSqlFunction.scala new file mode 100644 index 0000000..d875026 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/TimeMaterializationSqlFunction.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.functions + +import org.apache.calcite.sql._ +import org.apache.calcite.sql.`type`._ +import org.apache.calcite.sql.validate.SqlMonotonicity + +/** + * Function that materializes a time attribute to the metadata timestamp. After materialization + * the result can be used in regular arithmetical calculations. + */ +object TimeMaterializationSqlFunction + extends SqlFunction( + "TIME_MATERIALIZATION", + SqlKind.OTHER_FUNCTION, + ReturnTypes.explicit(SqlTypeName.TIMESTAMP), + InferTypes.RETURN_TYPE, + OperandTypes.family(SqlTypeFamily.TIMESTAMP), + SqlFunctionCategory.SYSTEM) { + + override def getSyntax: SqlSyntax = SqlSyntax.FUNCTION + + override def getMonotonicity(call: SqlOperatorBinding): SqlMonotonicity = + SqlMonotonicity.INCREASING +} http://git-wip-us.apache.org/repos/asf/flink/blob/495f104b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/TimeModeIndicatorFunctions.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/TimeModeIndicatorFunctions.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/TimeModeIndicatorFunctions.scala deleted file mode 100644 index 3ddcbdc..0000000 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/TimeModeIndicatorFunctions.scala +++ /dev/null @@ -1,137 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.flink.table.functions - -import java.nio.charset.Charset -import java.util - -import org.apache.calcite.rel.`type`._ -import org.apache.calcite.sql._ -import org.apache.calcite.sql.`type`.{OperandTypes, ReturnTypes, SqlTypeFamily, SqlTypeName} -import org.apache.calcite.sql.validate.SqlMonotonicity -import org.apache.calcite.tools.RelBuilder -import org.apache.flink.api.common.typeinfo.SqlTimeTypeInfo -import org.apache.flink.table.api.TableException -import org.apache.flink.table.expressions.LeafExpression - -object EventTimeExtractor extends SqlFunction("ROWTIME", SqlKind.OTHER_FUNCTION, - ReturnTypes.explicit(TimeModeTypes.ROWTIME), null, OperandTypes.NILADIC, - SqlFunctionCategory.SYSTEM) { - override def getSyntax: SqlSyntax = SqlSyntax.FUNCTION - - override def getMonotonicity(call: SqlOperatorBinding): SqlMonotonicity = - SqlMonotonicity.INCREASING -} - -object ProcTimeExtractor extends SqlFunction("PROCTIME", SqlKind.OTHER_FUNCTION, - ReturnTypes.explicit(TimeModeTypes.PROCTIME), null, OperandTypes.NILADIC, - SqlFunctionCategory.SYSTEM) { - override def getSyntax: SqlSyntax = SqlSyntax.FUNCTION - - override def getMonotonicity(call: SqlOperatorBinding): SqlMonotonicity = - SqlMonotonicity.INCREASING -} - -abstract class TimeIndicator extends LeafExpression { - /** - * Returns the [[org.apache.flink.api.common.typeinfo.TypeInformation]] - * for evaluating this expression. - * It is sometimes not available until the expression is valid. - */ - override private[flink] def resultType = SqlTimeTypeInfo.TIMESTAMP - - /** - * Convert Expression to its counterpart in Calcite, i.e. RexNode - */ - override private[flink] def toRexNode(implicit relBuilder: RelBuilder) = - throw new TableException("indicator functions (e.g. proctime() and rowtime()" + - " are not executable. Please check your expressions.") -} - -case class RowTime() extends TimeIndicator -case class ProcTime() extends TimeIndicator - -object TimeModeTypes { - - // indicator data type for row time (event time) - val ROWTIME = new RowTimeType - // indicator data type for processing time - val PROCTIME = new ProcTimeType - -} - -class RowTimeType extends TimeModeType { - - override def toString(): String = "ROWTIME" - override def getFullTypeString: String = "ROWTIME_INDICATOR" -} - -class ProcTimeType extends TimeModeType { - - override def toString(): String = "PROCTIME" - override def getFullTypeString: String = "PROCTIME_INDICATOR" -} - -abstract class TimeModeType extends RelDataType { - - override def getComparability: RelDataTypeComparability = RelDataTypeComparability.NONE - - override def isStruct: Boolean = false - - override def getFieldList: util.List[RelDataTypeField] = null - - override def getFieldNames: util.List[String] = null - - override def getFieldCount: Int = 0 - - override def getStructKind: StructKind = StructKind.NONE - - override def getField( - fieldName: String, - caseSensitive: Boolean, - elideRecord: Boolean): RelDataTypeField = null - - override def isNullable: Boolean = false - - override def getComponentType: RelDataType = null - - override def getKeyType: RelDataType = null - - override def getValueType: RelDataType = null - - override def getCharset: Charset = null - - override def getCollation: SqlCollation = null - - override def getIntervalQualifier: SqlIntervalQualifier = null - - override def getPrecision: Int = -1 - - override def getScale: Int = -1 - - override def getSqlTypeName: SqlTypeName = SqlTypeName.TIMESTAMP - - override def getSqlIdentifier: SqlIdentifier = null - - override def getFamily: RelDataTypeFamily = SqlTypeFamily.NUMERIC - - override def getPrecedenceList: RelDataTypePrecedenceList = ??? - - override def isDynamicStruct: Boolean = false - -} http://git-wip-us.apache.org/repos/asf/flink/blob/495f104b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/ProjectionTranslator.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/ProjectionTranslator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/ProjectionTranslator.scala index d26cdcf..98a7e63 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/ProjectionTranslator.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/ProjectionTranslator.scala @@ -19,9 +19,8 @@ package org.apache.flink.table.plan import org.apache.flink.api.common.typeutils.CompositeType -import org.apache.flink.table.api.{OverWindow, StreamTableEnvironment, TableEnvironment} +import org.apache.flink.table.api.{OverWindow, TableEnvironment} import org.apache.flink.table.expressions._ -import org.apache.flink.table.functions.{ProcTime, RowTime} import org.apache.flink.table.plan.logical.{LogicalNode, Project} import scala.collection.mutable @@ -231,28 +230,12 @@ object ProjectionTranslator { val overWindow = overWindows.find(_.alias.equals(unresolvedCall.alias)) if (overWindow.isDefined) { - if (tEnv.isInstanceOf[StreamTableEnvironment]) { - val timeIndicator = overWindow.get.orderBy match { - case u: UnresolvedFieldReference if u.name.toLowerCase == "rowtime" => - RowTime() - case u: UnresolvedFieldReference if u.name.toLowerCase == "proctime" => - ProcTime() - case e: Expression => e - } - OverCall( - unresolvedCall.agg, - overWindow.get.partitionBy, - timeIndicator, - overWindow.get.preceding, - overWindow.get.following) - } else { - OverCall( - unresolvedCall.agg, - overWindow.get.partitionBy, - overWindow.get.orderBy, - overWindow.get.preceding, - overWindow.get.following) - } + OverCall( + unresolvedCall.agg, + overWindow.get.partitionBy, + overWindow.get.orderBy, + overWindow.get.preceding, + overWindow.get.following) } else { unresolvedCall } http://git-wip-us.apache.org/repos/asf/flink/blob/495f104b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/LogicalWindow.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/LogicalWindow.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/LogicalWindow.scala index 1884e54..92dc501 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/LogicalWindow.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/LogicalWindow.scala @@ -22,14 +22,24 @@ import org.apache.flink.table.api.TableEnvironment import org.apache.flink.table.expressions.{Expression, WindowReference} import org.apache.flink.table.validate.{ValidationFailure, ValidationResult, ValidationSuccess} -abstract class LogicalWindow(val alias: Expression) extends Resolvable[LogicalWindow] { +/** + * Logical super class for all types of windows (group-windows and row-windows). + * + * @param aliasAttribute window alias + * @param timeAttribute time field indicating event-time or processing-time + */ +abstract class LogicalWindow( + val aliasAttribute: Expression, + val timeAttribute: Expression) + extends Resolvable[LogicalWindow] { def resolveExpressions(resolver: (Expression) => Expression): LogicalWindow = this - def validate(tableEnv: TableEnvironment): ValidationResult = alias match { + def validate(tableEnv: TableEnvironment): ValidationResult = aliasAttribute match { case WindowReference(_) => ValidationSuccess case _ => ValidationFailure("Window reference for window expected.") } override def toString: String = getClass.getSimpleName + } http://git-wip-us.apache.org/repos/asf/flink/blob/495f104b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/groupWindows.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/groupWindows.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/groupWindows.scala index 576756d..3e5de28 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/groupWindows.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/groupWindows.scala @@ -18,259 +18,165 @@ package org.apache.flink.table.plan.logical -import org.apache.flink.api.common.typeinfo.BasicTypeInfo import org.apache.flink.table.api.{BatchTableEnvironment, StreamTableEnvironment, TableEnvironment} +import org.apache.flink.table.expressions.ExpressionUtils.{isRowCountLiteral, isRowtimeAttribute, isTimeAttribute, isTimeIntervalLiteral} import org.apache.flink.table.expressions._ -import org.apache.flink.table.typeutils.{RowIntervalTypeInfo, TimeIntervalTypeInfo, TypeCoercion} +import org.apache.flink.table.typeutils.TypeCheckUtils.isTimePoint import org.apache.flink.table.validate.{ValidationFailure, ValidationResult, ValidationSuccess} -abstract class EventTimeGroupWindow( - alias: Expression, - time: Expression) - extends LogicalWindow(alias) { - - override def validate(tableEnv: TableEnvironment): ValidationResult = { - val valid = super.validate(tableEnv) - if (valid.isFailure) { - return valid - } - - tableEnv match { - case _: StreamTableEnvironment => - time match { - case RowtimeAttribute() => - ValidationSuccess - case _ => - ValidationFailure("Event-time window expects a 'rowtime' time field.") - } - case _: BatchTableEnvironment => - if (!TypeCoercion.canCast(time.resultType, BasicTypeInfo.LONG_TYPE_INFO)) { - ValidationFailure(s"Event-time window expects a time field that can be safely cast " + - s"to Long, but is ${time.resultType}") - } else { - ValidationSuccess - } - } - - } -} - -abstract class ProcessingTimeGroupWindow(alias: Expression) extends LogicalWindow(alias) { - override def validate(tableEnv: TableEnvironment): ValidationResult = { - val valid = super.validate(tableEnv) - if (valid.isFailure) { - return valid - } - - tableEnv match { - case b: BatchTableEnvironment => ValidationFailure( - "Window on batch must declare a time attribute over which the query is evaluated.") - case _ => - ValidationSuccess - } - } -} - // ------------------------------------------------------------------------------------------------ // Tumbling group windows // ------------------------------------------------------------------------------------------------ -object TumblingGroupWindow { - def validate(tableEnv: TableEnvironment, size: Expression): ValidationResult = size match { - case Literal(_, TimeIntervalTypeInfo.INTERVAL_MILLIS) => - ValidationSuccess - case Literal(_, RowIntervalTypeInfo.INTERVAL_ROWS) => - ValidationSuccess - case _ => - ValidationFailure("Tumbling window expects size literal of type Interval of Milliseconds " + - "or Interval of Rows.") - } -} - -case class ProcessingTimeTumblingGroupWindow( - override val alias: Expression, - size: Expression) - extends ProcessingTimeGroupWindow(alias) { - - override def resolveExpressions(resolve: (Expression) => Expression): LogicalWindow = - ProcessingTimeTumblingGroupWindow( - resolve(alias), - resolve(size)) - - override def validate(tableEnv: TableEnvironment): ValidationResult = - super.validate(tableEnv).orElse(TumblingGroupWindow.validate(tableEnv, size)) - - override def toString: String = s"ProcessingTimeTumblingGroupWindow($alias, $size)" -} - -case class EventTimeTumblingGroupWindow( - override val alias: Expression, +case class TumblingGroupWindow( + alias: Expression, timeField: Expression, size: Expression) - extends EventTimeGroupWindow( + extends LogicalWindow( alias, timeField) { override def resolveExpressions(resolve: (Expression) => Expression): LogicalWindow = - EventTimeTumblingGroupWindow( + TumblingGroupWindow( resolve(alias), resolve(timeField), resolve(size)) override def validate(tableEnv: TableEnvironment): ValidationResult = - super.validate(tableEnv) - .orElse(TumblingGroupWindow.validate(tableEnv, size)) - .orElse(size match { - case Literal(_, RowIntervalTypeInfo.INTERVAL_ROWS) - if tableEnv.isInstanceOf[StreamTableEnvironment] => + super.validate(tableEnv).orElse( + tableEnv match { + + // check size + case _ if !isTimeIntervalLiteral(size) && !isRowCountLiteral(size) => + ValidationFailure( + "Tumbling window expects size literal of type Interval of Milliseconds " + + "or Interval of Rows.") + + // check time attribute + case _: StreamTableEnvironment if !isTimeAttribute(timeField) => + ValidationFailure( + "Tumbling window expects a time attribute for grouping in a stream environment.") + case _: BatchTableEnvironment if isTimePoint(size.resultType) => + ValidationFailure( + "Tumbling window expects a time attribute for grouping in a stream environment.") + + // check row intervals on event-time + case _: StreamTableEnvironment + if isRowCountLiteral(size) && isRowtimeAttribute(timeField) => ValidationFailure( "Event-time grouping windows on row intervals in a stream environment " + "are currently not supported.") + case _ => ValidationSuccess - }) + } + ) - override def toString: String = s"EventTimeTumblingGroupWindow($alias, $timeField, $size)" + override def toString: String = s"TumblingGroupWindow($alias, $timeField, $size)" } // ------------------------------------------------------------------------------------------------ // Sliding group windows // ------------------------------------------------------------------------------------------------ -object SlidingGroupWindow { - def validate( - tableEnv: TableEnvironment, - size: Expression, - slide: Expression) - : ValidationResult = { - - val checkedSize = size match { - case Literal(_, TimeIntervalTypeInfo.INTERVAL_MILLIS) => - ValidationSuccess - case Literal(_, RowIntervalTypeInfo.INTERVAL_ROWS) => - ValidationSuccess - case _ => - ValidationFailure("Sliding window expects size literal of type Interval of " + - "Milliseconds or Interval of Rows.") - } - - val checkedSlide = slide match { - case Literal(_, TimeIntervalTypeInfo.INTERVAL_MILLIS) => - ValidationSuccess - case Literal(_, RowIntervalTypeInfo.INTERVAL_ROWS) => - ValidationSuccess - case _ => - ValidationFailure("Sliding window expects slide literal of type Interval of " + - "Milliseconds or Interval of Rows.") - } - - checkedSize - .orElse(checkedSlide) - .orElse { - if (size.resultType != slide.resultType) { - ValidationFailure("Sliding window expects same type of size and slide.") - } else { - ValidationSuccess - } - } - } -} - -case class ProcessingTimeSlidingGroupWindow( - override val alias: Expression, +case class SlidingGroupWindow( + alias: Expression, + timeField: Expression, size: Expression, slide: Expression) - extends ProcessingTimeGroupWindow(alias) { + extends LogicalWindow( + alias, + timeField) { override def resolveExpressions(resolve: (Expression) => Expression): LogicalWindow = - ProcessingTimeSlidingGroupWindow( + SlidingGroupWindow( resolve(alias), + resolve(timeField), resolve(size), resolve(slide)) override def validate(tableEnv: TableEnvironment): ValidationResult = - super.validate(tableEnv).orElse(SlidingGroupWindow.validate(tableEnv, size, slide)) + super.validate(tableEnv).orElse( + tableEnv match { - override def toString: String = s"ProcessingTimeSlidingGroupWindow($alias, $size, $slide)" -} + // check size + case _ if !isTimeIntervalLiteral(size) && !isRowCountLiteral(size) => + ValidationFailure( + "Sliding window expects size literal of type Interval of Milliseconds " + + "or Interval of Rows.") -case class EventTimeSlidingGroupWindow( - override val alias: Expression, - timeField: Expression, - size: Expression, - slide: Expression) - extends EventTimeGroupWindow(alias, timeField) { + // check slide + case _ if !isTimeIntervalLiteral(slide) && !isRowCountLiteral(slide) => + ValidationFailure( + "Sliding window expects slide literal of type Interval of Milliseconds " + + "or Interval of Rows.") - override def resolveExpressions(resolve: (Expression) => Expression): LogicalWindow = - EventTimeSlidingGroupWindow( - resolve(alias), - resolve(timeField), - resolve(size), - resolve(slide)) + // check same type of intervals + case _ if isTimeIntervalLiteral(size) != isTimeIntervalLiteral(slide) => + ValidationFailure("Sliding window expects same type of size and slide.") - override def validate(tableEnv: TableEnvironment): ValidationResult = - super.validate(tableEnv) - .orElse(SlidingGroupWindow.validate(tableEnv, size, slide)) - .orElse(size match { - case Literal(_, RowIntervalTypeInfo.INTERVAL_ROWS) - if tableEnv.isInstanceOf[StreamTableEnvironment] => + // check time attribute + case _: StreamTableEnvironment if !isTimeAttribute(timeField) => + ValidationFailure( + "Sliding window expects a time attribute for grouping in a stream environment.") + case _: BatchTableEnvironment if isTimePoint(size.resultType) => + ValidationFailure( + "Sliding window expects a time attribute for grouping in a stream environment.") + + // check row intervals on event-time + case _: StreamTableEnvironment + if isRowCountLiteral(size) && isRowtimeAttribute(timeField) => ValidationFailure( "Event-time grouping windows on row intervals in a stream environment " + "are currently not supported.") + case _ => ValidationSuccess - }) + } + ) - override def toString: String = s"EventTimeSlidingGroupWindow($alias, $timeField, $size, $slide)" + override def toString: String = s"SlidingGroupWindow($alias, $timeField, $size, $slide)" } // ------------------------------------------------------------------------------------------------ // Session group windows // ------------------------------------------------------------------------------------------------ -object SessionGroupWindow { - - def validate(tableEnv: TableEnvironment, gap: Expression): ValidationResult = gap match { - case Literal(timeInterval: Long, TimeIntervalTypeInfo.INTERVAL_MILLIS) => - ValidationSuccess - case _ => - ValidationFailure( - "Session window expects gap literal of type Interval of Milliseconds.") - } -} - -case class ProcessingTimeSessionGroupWindow( - override val alias: Expression, - gap: Expression) - extends ProcessingTimeGroupWindow(alias) { - - override def resolveExpressions(resolve: (Expression) => Expression): LogicalWindow = - ProcessingTimeSessionGroupWindow( - resolve(alias), - resolve(gap)) - - override def validate(tableEnv: TableEnvironment): ValidationResult = - super.validate(tableEnv).orElse(SessionGroupWindow.validate(tableEnv, gap)) - - override def toString: String = s"ProcessingTimeSessionGroupWindow($alias, $gap)" -} - -case class EventTimeSessionGroupWindow( - override val alias: Expression, +case class SessionGroupWindow( + alias: Expression, timeField: Expression, gap: Expression) - extends EventTimeGroupWindow( + extends LogicalWindow( alias, timeField) { override def resolveExpressions(resolve: (Expression) => Expression): LogicalWindow = - EventTimeSessionGroupWindow( + SessionGroupWindow( resolve(alias), resolve(timeField), resolve(gap)) override def validate(tableEnv: TableEnvironment): ValidationResult = - super.validate(tableEnv).orElse(SessionGroupWindow.validate(tableEnv, gap)) + super.validate(tableEnv).orElse( + tableEnv match { + + // check size + case _ if !isTimeIntervalLiteral(gap) => + ValidationFailure( + "Session window expects size literal of type Interval of Milliseconds.") + + // check time attribute + case _: StreamTableEnvironment if !isTimeAttribute(timeField) => + ValidationFailure( + "Session window expects a time attribute for grouping in a stream environment.") + case _: BatchTableEnvironment if isTimePoint(gap.resultType) => + ValidationFailure( + "Session window expects a time attribute for grouping in a stream environment.") + + case _ => + ValidationSuccess + } + ) - override def toString: String = s"EventTimeSessionGroupWindow($alias, $timeField, $gap)" + override def toString: String = s"SessionGroupWindow($alias, $timeField, $gap)" } http://git-wip-us.apache.org/repos/asf/flink/blob/495f104b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala index 5f2394c..3839145 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala @@ -70,8 +70,6 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalNode) extend def checkName(name: String): Unit = { if (names.contains(name)) { failValidation(s"Duplicate field name $name.") - } else if (tableEnv.isInstanceOf[StreamTableEnvironment] && name == "rowtime") { - failValidation("'rowtime' cannot be used as field name in a streaming environment.") } else { names.add(name) } @@ -112,10 +110,6 @@ case class AliasNode(aliasList: Seq[Expression], child: LogicalNode) extends Una failValidation("Alias only accept name expressions as arguments") } else if (!aliasList.forall(_.asInstanceOf[UnresolvedFieldReference].name != "*")) { failValidation("Alias can not accept '*' as name") - } else if (tableEnv.isInstanceOf[StreamTableEnvironment] && !aliasList.forall { - case UnresolvedFieldReference(name) => name != "rowtime" - }) { - failValidation("'rowtime' cannot be used as field name in a streaming environment.") } else { val names = aliasList.map(_.asInstanceOf[UnresolvedFieldReference].name) val input = child.output @@ -561,26 +555,20 @@ case class WindowAggregate( override def resolveReference( tableEnv: TableEnvironment, name: String) - : Option[NamedExpression] = tableEnv match { - // resolve reference to rowtime attribute in a streaming environment - case _: StreamTableEnvironment if name == "rowtime" => - Some(RowtimeAttribute()) - case _ => - window.alias match { - // resolve reference to this window's alias - case UnresolvedFieldReference(alias) if name == alias => - // check if reference can already be resolved by input fields - val found = super.resolveReference(tableEnv, name) - if (found.isDefined) { - failValidation(s"Reference $name is ambiguous.") - } else { - Some(WindowReference(name)) - } - case _ => - // resolve references as usual - super.resolveReference(tableEnv, name) - } - } + : Option[NamedExpression] = window.aliasAttribute match { + // resolve reference to this window's name + case UnresolvedFieldReference(alias) if name == alias => + // check if reference can already be resolved by input fields + val found = super.resolveReference(tableEnv, name) + if (found.isDefined) { + failValidation(s"Reference $name is ambiguous.") + } else { + Some(WindowReference(name)) + } + case _ => + // resolve references as usual + super.resolveReference(tableEnv, name) + } override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = { val flinkRelBuilder = relBuilder.asInstanceOf[FlinkRelBuilder] http://git-wip-us.apache.org/repos/asf/flink/blob/495f104b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCalc.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCalc.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCalc.scala index 96a7470..5c35129 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCalc.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCalc.scala @@ -19,13 +19,12 @@ package org.apache.flink.table.plan.nodes import org.apache.calcite.plan.{RelOptCost, RelOptPlanner} -import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rex._ import org.apache.flink.api.common.functions.{FlatMapFunction, RichFlatMapFunction} -import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.table.api.TableConfig import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.codegen.{CodeGenerator, GeneratedFunction} +import org.apache.flink.table.plan.schema.RowSchema import org.apache.flink.table.runtime.FlatMapRunner import org.apache.flink.types.Row @@ -35,21 +34,30 @@ import scala.collection.JavaConverters._ trait CommonCalc { private[flink] def functionBody( - generator: CodeGenerator, - inputType: TypeInformation[Row], - rowType: RelDataType, - calcProgram: RexProgram, - config: TableConfig) + generator: CodeGenerator, + inputSchema: RowSchema, + returnSchema: RowSchema, + calcProgram: RexProgram, + config: TableConfig) : String = { - val returnType = FlinkTypeFactory.toInternalRowTypeInfo(rowType) + val expandedExpressions = calcProgram + .getProjectList + .map(expr => calcProgram.expandLocalRef(expr)) + // time indicator fields must not be part of the code generation + .filter(expr => !FlinkTypeFactory.isTimeIndicatorType(expr.getType)) + // update indices + .map(expr => inputSchema.mapRexNode(expr)) + + val condition = if (calcProgram.getCondition != null) { + inputSchema.mapRexNode(calcProgram.expandLocalRef(calcProgram.getCondition)) + } else { + null + } - val condition = calcProgram.getCondition - val expandedExpressions = calcProgram.getProjectList.map( - expr => calcProgram.expandLocalRef(expr)) val projection = generator.generateResultExpression( - returnType, - rowType.getFieldNames, + returnSchema.physicalTypeInfo, + returnSchema.physicalFieldNames, expandedExpressions) // only projection @@ -60,8 +68,7 @@ trait CommonCalc { |""".stripMargin } else { - val filterCondition = generator.generateExpression( - calcProgram.expandLocalRef(calcProgram.getCondition)) + val filterCondition = generator.generateExpression(condition) // only filter if (projection == null) { s""" http://git-wip-us.apache.org/repos/asf/flink/blob/495f104b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCorrelate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCorrelate.scala index 6c4066b..02305ee 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCorrelate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCorrelate.scala @@ -23,11 +23,11 @@ import org.apache.calcite.sql.SemiJoinType import org.apache.flink.api.common.functions.FlatMapFunction import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.table.api.{TableConfig, TableException} -import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.codegen.CodeGenUtils.primitiveDefaultValue import org.apache.flink.table.codegen.GeneratedExpression.{ALWAYS_NULL, NO_CODE} import org.apache.flink.table.codegen.{CodeGenerator, GeneratedCollector, GeneratedExpression, GeneratedFunction} import org.apache.flink.table.functions.utils.TableSqlFunction +import org.apache.flink.table.plan.schema.RowSchema import org.apache.flink.table.runtime.{CorrelateFlatMapRunner, TableFunctionCollector} import org.apache.flink.types.Row @@ -44,9 +44,9 @@ trait CommonCorrelate { */ private[flink] def correlateMapFunction( config: TableConfig, - inputTypeInfo: TypeInformation[Row], + inputSchema: RowSchema, udtfTypeInfo: TypeInformation[Any], - rowType: RelDataType, + returnSchema: RowSchema, joinType: SemiJoinType, rexCall: RexCall, condition: Option[RexNode], @@ -54,26 +54,24 @@ trait CommonCorrelate { ruleDescription: String) : CorrelateFlatMapRunner[Row, Row] = { - val returnType = FlinkTypeFactory.toInternalRowTypeInfo(rowType) - val flatMap = generateFunction( config, - inputTypeInfo, + inputSchema.physicalTypeInfo, udtfTypeInfo, - returnType, - rowType, + returnSchema.physicalTypeInfo, + returnSchema.logicalFieldNames, joinType, - rexCall, + inputSchema.mapRexNode(rexCall).asInstanceOf[RexCall], pojoFieldMapping, ruleDescription) val collector = generateCollector( config, - inputTypeInfo, + inputSchema.physicalTypeInfo, udtfTypeInfo, - returnType, - rowType, - condition, + returnSchema.physicalTypeInfo, + returnSchema.logicalFieldNames, + condition.map(inputSchema.mapRexNode), pojoFieldMapping) new CorrelateFlatMapRunner[Row, Row]( @@ -93,7 +91,7 @@ trait CommonCorrelate { inputTypeInfo: TypeInformation[Row], udtfTypeInfo: TypeInformation[Any], returnType: TypeInformation[Row], - rowType: RelDataType, + resultFieldNames: Seq[String], joinType: SemiJoinType, rexCall: RexCall, pojoFieldMapping: Option[Array[Int]], @@ -134,7 +132,7 @@ trait CommonCorrelate { x.resultType) } val outerResultExpr = functionGenerator.generateResultExpression( - input1AccessExprs ++ input2NullExprs, returnType, rowType.getFieldNames.asScala) + input1AccessExprs ++ input2NullExprs, returnType, resultFieldNames) body += s""" |boolean hasOutput = $collectorTerm.isCollected(); @@ -162,7 +160,7 @@ trait CommonCorrelate { inputTypeInfo: TypeInformation[Row], udtfTypeInfo: TypeInformation[Any], returnType: TypeInformation[Row], - rowType: RelDataType, + resultFieldNames: Seq[String], condition: Option[RexNode], pojoFieldMapping: Option[Array[Int]]) : GeneratedCollector = { @@ -180,7 +178,7 @@ trait CommonCorrelate { val crossResultExpr = generator.generateResultExpression( input1AccessExprs ++ input2AccessExprs, returnType, - rowType.getFieldNames.asScala) + resultFieldNames) val collectorCode = if (condition.isEmpty) { s""" http://git-wip-us.apache.org/repos/asf/flink/blob/495f104b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonScan.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonScan.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonScan.scala index 0a0d204..091a1ea 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonScan.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonScan.scala @@ -18,11 +18,10 @@ package org.apache.flink.table.plan.nodes -import org.apache.flink.api.common.functions.MapFunction +import org.apache.flink.api.common.functions.Function import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.table.api.TableConfig -import org.apache.flink.table.codegen.CodeGenerator -import org.apache.flink.table.runtime.MapRunner +import org.apache.flink.table.codegen.{CodeGenerator, GeneratedFunction} import org.apache.flink.types.Row /** @@ -42,21 +41,22 @@ trait CommonScan { externalTypeInfo != internalTypeInfo } - private[flink] def getConversionMapper( + private[flink] def generatedConversionFunction[F <: Function]( config: TableConfig, + functionClass: Class[F], inputType: TypeInformation[Any], expectedType: TypeInformation[Row], conversionOperatorName: String, fieldNames: Seq[String], - inputPojoFieldMapping: Option[Array[Int]] = None) - : MapFunction[Any, Row] = { + inputFieldMapping: Option[Array[Int]] = None) + : GeneratedFunction[F, Row] = { val generator = new CodeGenerator( config, false, inputType, None, - inputPojoFieldMapping) + inputFieldMapping) val conversion = generator.generateConverterResultExpression(expectedType, fieldNames) val body = @@ -65,17 +65,11 @@ trait CommonScan { |return ${conversion.resultTerm}; |""".stripMargin - val genFunction = generator.generateFunction( + generator.generateFunction( conversionOperatorName, - classOf[MapFunction[Any, Row]], + functionClass, body, expectedType) - - new MapRunner[Any, Row]( - genFunction.name, - genFunction.code, - genFunction.returnType) - } } http://git-wip-us.apache.org/repos/asf/flink/blob/495f104b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala index 6878473..1048549 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala @@ -18,14 +18,13 @@ package org.apache.flink.table.plan.nodes -import org.apache.calcite.rel.{RelFieldCollation, RelNode} -import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeFieldImpl} -import org.apache.calcite.rel.core.AggregateCall +import org.apache.calcite.rel.`type`.RelDataType +import org.apache.calcite.rel.core.{AggregateCall, Window} import org.apache.calcite.rel.core.Window.Group -import org.apache.calcite.rel.core.Window -import org.apache.calcite.rex.{RexInputRef} +import org.apache.calcite.rel.{RelFieldCollation, RelNode} +import org.apache.calcite.rex.RexInputRef +import org.apache.flink.table.plan.schema.RowSchema import org.apache.flink.table.runtime.aggregate.AggregateUtil._ -import org.apache.flink.table.functions.{ProcTimeType, RowTimeType} import scala.collection.JavaConverters._ @@ -43,7 +42,7 @@ trait OverAggregate { val inFields = inputType.getFieldList.asScala val orderingString = orderFields.asScala.map { - x => inFields(x.getFieldIndex).getValue + x => inFields(x.getFieldIndex).getName }.mkString(", ") orderingString @@ -66,24 +65,8 @@ trait OverAggregate { rowType: RelDataType, namedAggregates: Seq[CalcitePair[AggregateCall, String]]): String = { - val inFields = inputType.getFieldList.asScala.map { - x => - x.asInstanceOf[RelDataTypeFieldImpl].getType - match { - case proceTime: ProcTimeType => "PROCTIME" - case rowTime: RowTimeType => "ROWTIME" - case _ => x.asInstanceOf[RelDataTypeFieldImpl].getName - } - } - val outFields = rowType.getFieldList.asScala.map { - x => - x.asInstanceOf[RelDataTypeFieldImpl].getType - match { - case proceTime: ProcTimeType => "PROCTIME" - case rowTime: RowTimeType => "ROWTIME" - case _ => x.asInstanceOf[RelDataTypeFieldImpl].getName - } - } + val inFields = inputType.getFieldNames.asScala + val outFields = rowType.getFieldNames.asScala val aggStrings = namedAggregates.map(_.getKey).map( a => s"${a.getAggregation}(${ @@ -109,7 +92,7 @@ trait OverAggregate { input: RelNode): Long = { val ref: RexInputRef = overWindow.lowerBound.getOffset.asInstanceOf[RexInputRef] - val lowerBoundIndex = input.getRowType.getFieldCount - ref.getIndex; + val lowerBoundIndex = input.getRowType.getFieldCount - ref.getIndex val lowerBound = logicWindow.constants.get(lowerBoundIndex).getValue2 lowerBound match { case x: java.math.BigDecimal => x.asInstanceOf[java.math.BigDecimal].longValue() http://git-wip-us.apache.org/repos/asf/flink/blob/495f104b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/PhysicalTableSourceScan.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/PhysicalTableSourceScan.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/PhysicalTableSourceScan.scala index d924450..c18c3d1 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/PhysicalTableSourceScan.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/PhysicalTableSourceScan.scala @@ -37,9 +37,11 @@ abstract class PhysicalTableSourceScan( override def deriveRowType(): RelDataType = { val flinkTypeFactory = cluster.getTypeFactory.asInstanceOf[FlinkTypeFactory] - flinkTypeFactory.buildRowDataType( + flinkTypeFactory.buildLogicalRowType( TableEnvironment.getFieldNames(tableSource), - TableEnvironment.getFieldTypes(tableSource.getReturnType)) + TableEnvironment.getFieldTypes(tableSource.getReturnType), + None, + None) } override def explainTerms(pw: RelWriter): RelWriter = { http://git-wip-us.apache.org/repos/asf/flink/blob/495f104b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/BatchScan.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/BatchScan.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/BatchScan.scala index b39b8ed..cc5d9fb 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/BatchScan.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/BatchScan.scala @@ -18,11 +18,13 @@ package org.apache.flink.table.plan.nodes.dataset +import org.apache.flink.api.common.functions.MapFunction import org.apache.flink.api.java.DataSet import org.apache.flink.table.api.TableConfig import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.plan.nodes.CommonScan import org.apache.flink.table.plan.schema.FlinkTable +import org.apache.flink.table.runtime.MapRunner import org.apache.flink.types.Row import scala.collection.JavaConversions._ @@ -43,17 +45,23 @@ trait BatchScan extends CommonScan with DataSetRel { // conversion if (needsConversion(inputType, internalType)) { - val mapFunc = getConversionMapper( + val function = generatedConversionFunction( config, + classOf[MapFunction[Any, Row]], inputType, internalType, "DataSetSourceConversion", getRowType.getFieldNames, Some(flinkTable.fieldIndexes)) + val runner = new MapRunner[Any, Row]( + function.name, + function.code, + function.returnType) + val opName = s"from: (${getRowType.getFieldNames.asScala.toList.mkString(", ")})" - input.map(mapFunc).name(opName) + input.map(runner).name(opName) } // no conversion necessary, forward else { http://git-wip-us.apache.org/repos/asf/flink/blob/495f104b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/BatchTableSourceScan.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/BatchTableSourceScan.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/BatchTableSourceScan.scala index bf4291a..fb291e4 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/BatchTableSourceScan.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/BatchTableSourceScan.scala @@ -22,7 +22,8 @@ import org.apache.calcite.plan._ import org.apache.calcite.rel.RelNode import org.apache.calcite.rel.metadata.RelMetadataQuery import org.apache.flink.api.java.DataSet -import org.apache.flink.table.api.BatchTableEnvironment +import org.apache.flink.table.api.{BatchTableEnvironment, TableEnvironment} +import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.plan.nodes.PhysicalTableSourceScan import org.apache.flink.table.plan.schema.TableSourceTable import org.apache.flink.table.sources.{BatchTableSource, TableSource} @@ -37,7 +38,16 @@ class BatchTableSourceScan( extends PhysicalTableSourceScan(cluster, traitSet, table, tableSource) with BatchScan { - override def computeSelfCost(planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = { + override def deriveRowType() = { + val flinkTypeFactory = cluster.getTypeFactory.asInstanceOf[FlinkTypeFactory] + flinkTypeFactory.buildLogicalRowType( + TableEnvironment.getFieldNames(tableSource), + TableEnvironment.getFieldTypes(tableSource.getReturnType), + None, + None) + } + + override def computeSelfCost (planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = { val rowCnt = metadata.getRowCount(this) planner.getCostFactory.makeCost(rowCnt, rowCnt, rowCnt * estimateRowSize(getRowType)) } http://git-wip-us.apache.org/repos/asf/flink/blob/495f104b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala index b92775c..c22dc54 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala @@ -91,6 +91,9 @@ class DataSetAggregate( override def translateToPlan(tableEnv: BatchTableEnvironment): DataSet[Row] = { val inputDS = getInput.asInstanceOf[DataSetRel].translateToPlan(tableEnv) + val input = inputNode.asInstanceOf[DataSetRel] + + val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType).asInstanceOf[RowTypeInfo] val generator = new CodeGenerator( tableEnv.getConfig, @@ -104,15 +107,14 @@ class DataSetAggregate( ) = AggregateUtil.createDataSetAggregateFunctions( generator, namedAggregates, - inputType, + input.getRowType, + inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes, rowRelDataType, grouping, inGroupingSet) val aggString = aggregationToString(inputType, grouping, getRowType, namedAggregates, Nil) - val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType).asInstanceOf[RowTypeInfo] - if (grouping.length > 0) { // grouped aggregation val aggOpName = s"groupBy: (${groupingToString(inputType, grouping)}), " + http://git-wip-us.apache.org/repos/asf/flink/blob/495f104b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCalc.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCalc.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCalc.scala index e05b5a8..9e18082 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCalc.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCalc.scala @@ -26,10 +26,13 @@ import org.apache.calcite.rel.{RelNode, RelWriter} import org.apache.calcite.rex._ import org.apache.flink.api.common.functions.FlatMapFunction import org.apache.flink.api.java.DataSet +import org.apache.flink.api.java.typeutils.RowTypeInfo import org.apache.flink.table.api.BatchTableEnvironment import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.codegen.CodeGenerator import org.apache.flink.table.plan.nodes.CommonCalc +import org.apache.flink.table.plan.schema.RowSchema +import org.apache.flink.table.runtime.FlatMapRunner import org.apache.flink.types.Row /** @@ -83,14 +86,14 @@ class DataSetCalc( val inputDS = getInput.asInstanceOf[DataSetRel].translateToPlan(tableEnv) - val returnType = FlinkTypeFactory.toInternalRowTypeInfo(getRowType) - val generator = new CodeGenerator(config, false, inputDS.getType) + val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType).asInstanceOf[RowTypeInfo] + val body = functionBody( generator, - inputDS.getType, - getRowType, + new RowSchema(getInput.getRowType), + new RowSchema(getRowType), calcProgram, config) @@ -98,9 +101,13 @@ class DataSetCalc( ruleDescription, classOf[FlatMapFunction[Row, Row]], body, - returnType) + rowTypeInfo) + + val runner = new FlatMapRunner[Row, Row]( + genFunction.name, + genFunction.code, + genFunction.returnType) - val mapFunc = calcMapFunction(genFunction) - inputDS.flatMap(mapFunc).name(calcOpName(calcProgram, getExpressionString)) + inputDS.flatMap(runner).name(calcOpName(calcProgram, getExpressionString)) } } http://git-wip-us.apache.org/repos/asf/flink/blob/495f104b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala index 2a62e21..6c79b45 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala @@ -25,10 +25,13 @@ import org.apache.calcite.rex.{RexCall, RexNode} import org.apache.calcite.sql.SemiJoinType import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.DataSet +import org.apache.flink.api.java.typeutils.RowTypeInfo import org.apache.flink.table.api.BatchTableEnvironment +import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.functions.utils.TableSqlFunction import org.apache.flink.table.plan.nodes.CommonCorrelate import org.apache.flink.table.plan.nodes.logical.FlinkLogicalTableFunctionScan +import org.apache.flink.table.plan.schema.RowSchema import org.apache.flink.types.Row /** @@ -98,11 +101,13 @@ class DataSetCorrelate( val pojoFieldMapping = sqlFunction.getPojoFieldMapping val udtfTypeInfo = sqlFunction.getRowTypeInfo.asInstanceOf[TypeInformation[Any]] + val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType).asInstanceOf[RowTypeInfo] + val mapFunc = correlateMapFunction( config, - inputDS.getType, + new RowSchema(getInput.getRowType), udtfTypeInfo, - getRowType, + new RowSchema(getRowType), joinType, rexCall, condition, http://git-wip-us.apache.org/repos/asf/flink/blob/495f104b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala index 96c427e..3cb872a 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala @@ -24,15 +24,16 @@ import org.apache.calcite.rel.metadata.RelMetadataQuery import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel} import org.apache.flink.api.common.operators.Order import org.apache.flink.api.java.DataSet -import org.apache.flink.api.java.typeutils.ResultTypeQueryable +import org.apache.flink.api.java.typeutils.{ResultTypeQueryable, RowTypeInfo} import org.apache.flink.table.api.BatchTableEnvironment import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.codegen.CodeGenerator +import org.apache.flink.table.expressions.ExpressionUtils._ import org.apache.flink.table.plan.logical._ import org.apache.flink.table.plan.nodes.CommonAggregate import org.apache.flink.table.runtime.aggregate.AggregateUtil.{CalcitePair, _} -import org.apache.flink.table.typeutils.TypeCheckUtils.isTimeInterval +import org.apache.flink.table.typeutils.TypeCheckUtils.{isLong, isTimePoint} import org.apache.flink.types.Row /** @@ -106,8 +107,6 @@ class DataSetWindowAggregate( override def translateToPlan(tableEnv: BatchTableEnvironment): DataSet[Row] = { - val config = tableEnv.getConfig - val inputDS = getInput.asInstanceOf[DataSetRel].translateToPlan(tableEnv) val generator = new CodeGenerator( @@ -119,30 +118,31 @@ class DataSetWindowAggregate( val caseSensitive = tableEnv.getFrameworkConfig.getParserConfig.caseSensitive() window match { - case EventTimeTumblingGroupWindow(_, _, size) => + case TumblingGroupWindow(_, timeField, size) + if isTimePoint(timeField.resultType) || isLong(timeField.resultType) => createEventTimeTumblingWindowDataSet( generator, inputDS, - isTimeInterval(size.resultType), + isTimeIntervalLiteral(size), caseSensitive) - case EventTimeSessionGroupWindow(_, _, gap) => + case SessionGroupWindow(_, timeField, gap) + if isTimePoint(timeField.resultType) || isLong(timeField.resultType) => createEventTimeSessionWindowDataSet(generator, inputDS, caseSensitive) - case EventTimeSlidingGroupWindow(_, _, size, slide) => + case SlidingGroupWindow(_, timeField, size, slide) + if isTimePoint(timeField.resultType) || isLong(timeField.resultType) => createEventTimeSlidingWindowDataSet( generator, inputDS, - isTimeInterval(size.resultType), + isTimeIntervalLiteral(size), asLong(size), asLong(slide), caseSensitive) - case _: ProcessingTimeGroupWindow => + case _ => throw new UnsupportedOperationException( - "Processing-time tumbling windows are not supported in a batch environment, " + - "windows in a batch environment must declare a time attribute over which " + - "the query is evaluated.") + s"Window $window is not supported in a batch environment.") } } @@ -152,18 +152,22 @@ class DataSetWindowAggregate( isTimeWindow: Boolean, isParserCaseSensitive: Boolean): DataSet[Row] = { + val input = inputNode.asInstanceOf[DataSetRel] + val mapFunction = createDataSetWindowPrepareMapFunction( generator, window, namedAggregates, grouping, - inputType, + input.getRowType, + inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes, isParserCaseSensitive) val groupReduceFunction = createDataSetWindowAggregationGroupReduceFunction( generator, window, namedAggregates, - inputType, + input.getRowType, + inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes, getRowType, grouping, namedProperties) @@ -210,6 +214,8 @@ class DataSetWindowAggregate( inputDS: DataSet[Row], isParserCaseSensitive: Boolean): DataSet[Row] = { + val input = inputNode.asInstanceOf[DataSetRel] + val groupingKeys = grouping.indices.toArray val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType) @@ -219,7 +225,8 @@ class DataSetWindowAggregate( window, namedAggregates, grouping, - inputType, + input.getRowType, + inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes, isParserCaseSensitive) val mappedInput = inputDS.map(mapFunction).name(prepareOperatorName) @@ -245,7 +252,8 @@ class DataSetWindowAggregate( generator, window, namedAggregates, - inputType, + input.getRowType, + inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes, grouping) // create groupReduceFunction for calculating the aggregations @@ -253,7 +261,8 @@ class DataSetWindowAggregate( generator, window, namedAggregates, - inputType, + input.getRowType, + inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes, rowRelDataType, grouping, namedProperties, @@ -275,7 +284,8 @@ class DataSetWindowAggregate( generator, window, namedAggregates, - inputType, + input.getRowType, + inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes, grouping) // create groupReduceFunction for calculating the aggregations @@ -283,7 +293,8 @@ class DataSetWindowAggregate( generator, window, namedAggregates, - inputType, + input.getRowType, + inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes, rowRelDataType, grouping, namedProperties, @@ -308,7 +319,8 @@ class DataSetWindowAggregate( generator, window, namedAggregates, - inputType, + input.getRowType, + inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes, rowRelDataType, grouping, namedProperties) @@ -324,7 +336,8 @@ class DataSetWindowAggregate( generator, window, namedAggregates, - inputType, + input.getRowType, + inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes, rowRelDataType, grouping, namedProperties) @@ -347,6 +360,8 @@ class DataSetWindowAggregate( isParserCaseSensitive: Boolean) : DataSet[Row] = { + val input = inputNode.asInstanceOf[DataSetRel] + // create MapFunction for initializing the aggregations // it aligns the rowtime for pre-tumbling in case of a time-window for partial aggregates val mapFunction = createDataSetWindowPrepareMapFunction( @@ -354,7 +369,8 @@ class DataSetWindowAggregate( window, namedAggregates, grouping, - inputType, + input.getRowType, + inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes, isParserCaseSensitive) val mappedDataSet = inputDS @@ -390,7 +406,8 @@ class DataSetWindowAggregate( window, namedAggregates, grouping, - inputType, + input.getRowType, + inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes, isParserCaseSensitive) mappedDataSet.asInstanceOf[DataSet[Row]] @@ -426,7 +443,8 @@ class DataSetWindowAggregate( generator, window, namedAggregates, - inputType, + input.getRowType, + inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes, rowRelDataType, grouping, namedProperties, http://git-wip-us.apache.org/repos/asf/flink/blob/495f104b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala index c232a71..5697449 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala @@ -25,20 +25,18 @@ import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel} import org.apache.flink.api.java.tuple.Tuple import org.apache.flink.streaming.api.datastream.{AllWindowedStream, DataStream, KeyedStream, WindowedStream} import org.apache.flink.streaming.api.windowing.assigners._ -import org.apache.flink.streaming.api.windowing.time.Time import org.apache.flink.streaming.api.windowing.windows.{Window => DataStreamWindow} import org.apache.flink.table.api.StreamTableEnvironment import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty -import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.codegen.CodeGenerator -import org.apache.flink.table.expressions._ +import org.apache.flink.table.expressions.ExpressionUtils._ import org.apache.flink.table.plan.logical._ import org.apache.flink.table.plan.nodes.CommonAggregate import org.apache.flink.table.plan.nodes.datastream.DataStreamAggregate._ +import org.apache.flink.table.plan.schema.RowSchema import org.apache.flink.table.runtime.aggregate.AggregateUtil._ import org.apache.flink.table.runtime.aggregate._ import org.apache.flink.table.typeutils.TypeCheckUtils.isTimeInterval -import org.apache.flink.table.typeutils.{RowIntervalTypeInfo, TimeIntervalTypeInfo} import org.apache.flink.types.Row class DataStreamAggregate( @@ -48,12 +46,12 @@ class DataStreamAggregate( traitSet: RelTraitSet, inputNode: RelNode, namedAggregates: Seq[CalcitePair[AggregateCall, String]], - rowRelDataType: RelDataType, - inputType: RelDataType, + schema: RowSchema, + inputSchema: RowSchema, grouping: Array[Int]) extends SingleRel(cluster, traitSet, inputNode) with CommonAggregate with DataStreamRel { - override def deriveRowType(): RelDataType = rowRelDataType + override def deriveRowType(): RelDataType = schema.logicalType override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = { new DataStreamAggregate( @@ -63,22 +61,22 @@ class DataStreamAggregate( traitSet, inputs.get(0), namedAggregates, - getRowType, - inputType, + schema, + inputSchema, grouping) } override def toString: String = { s"Aggregate(${ if (!grouping.isEmpty) { - s"groupBy: (${groupingToString(inputType, grouping)}), " + s"groupBy: (${groupingToString(inputSchema.logicalType, grouping)}), " } else { "" } }window: ($window), " + s"select: (${ aggregationToString( - inputType, + inputSchema.logicalType, grouping, getRowType, namedAggregates, @@ -88,13 +86,13 @@ class DataStreamAggregate( override def explainTerms(pw: RelWriter): RelWriter = { super.explainTerms(pw) - .itemIf("groupBy", groupingToString(inputType, grouping), !grouping.isEmpty) + .itemIf("groupBy", groupingToString(inputSchema.logicalType, grouping), !grouping.isEmpty) .item("window", window) .item( "select", aggregationToString( - inputType, + inputSchema.logicalType, grouping, - getRowType, + schema.logicalType, namedAggregates, namedProperties)) } @@ -102,17 +100,20 @@ class DataStreamAggregate( override def translateToPlan(tableEnv: StreamTableEnvironment): DataStream[Row] = { val inputDS = input.asInstanceOf[DataStreamRel].translateToPlan(tableEnv) - - val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType) + val physicalNamedAggregates = namedAggregates.map { namedAggregate => + new CalcitePair[AggregateCall, String]( + inputSchema.mapAggregateCall(namedAggregate.left), + namedAggregate.right) + } val aggString = aggregationToString( - inputType, + inputSchema.logicalType, grouping, - getRowType, + schema.logicalType, namedAggregates, namedProperties) - val keyedAggOpName = s"groupBy: (${groupingToString(inputType, grouping)}), " + + val keyedAggOpName = s"groupBy: (${groupingToString(schema.logicalType, grouping)}), " + s"window: ($window), " + s"select: ($aggString)" val nonKeyedAggOpName = s"window: ($window), select: ($aggString)" @@ -123,21 +124,21 @@ class DataStreamAggregate( inputDS.getType) val needMerge = window match { - case ProcessingTimeSessionGroupWindow(_, _) => true - case EventTimeSessionGroupWindow(_, _, _) => true + case SessionGroupWindow(_, _, _) => true case _ => false } + val physicalGrouping = grouping.map(inputSchema.mapIndex) // grouped / keyed aggregation - if (grouping.length > 0) { + if (physicalGrouping.length > 0) { val windowFunction = AggregateUtil.createAggregationGroupWindowFunction( window, - grouping.length, - namedAggregates.size, - rowRelDataType.getFieldCount, + physicalGrouping.length, + physicalNamedAggregates.size, + schema.physicalArity, namedProperties) - val keyedStream = inputDS.keyBy(grouping: _*) + val keyedStream = inputDS.keyBy(physicalGrouping: _*) val windowedStream = createKeyedWindowedStream(window, keyedStream) .asInstanceOf[WindowedStream[Row, Tuple, DataStreamWindow]] @@ -145,20 +146,26 @@ class DataStreamAggregate( val (aggFunction, accumulatorRowType, aggResultRowType) = AggregateUtil.createDataStreamAggregateFunction( generator, - namedAggregates, - inputType, - rowRelDataType, + physicalNamedAggregates, + inputSchema.physicalType, + inputSchema.physicalFieldTypeInfo, + schema.physicalType, needMerge) windowedStream - .aggregate(aggFunction, windowFunction, accumulatorRowType, aggResultRowType, rowTypeInfo) + .aggregate( + aggFunction, + windowFunction, + accumulatorRowType, + aggResultRowType, + schema.physicalTypeInfo) .name(keyedAggOpName) } // global / non-keyed aggregation else { val windowFunction = AggregateUtil.createAggregationAllWindowFunction( window, - rowRelDataType.getFieldCount, + schema.physicalArity, namedProperties) val windowedStream = @@ -168,13 +175,19 @@ class DataStreamAggregate( val (aggFunction, accumulatorRowType, aggResultRowType) = AggregateUtil.createDataStreamAggregateFunction( generator, - namedAggregates, - inputType, - rowRelDataType, + physicalNamedAggregates, + inputSchema.physicalType, + inputSchema.physicalFieldTypeInfo, + schema.physicalType, needMerge) windowedStream - .aggregate(aggFunction, windowFunction, accumulatorRowType, aggResultRowType, rowTypeInfo) + .aggregate( + aggFunction, + windowFunction, + accumulatorRowType, + aggResultRowType, + schema.physicalTypeInfo) .name(nonKeyedAggOpName) } } @@ -186,95 +199,102 @@ object DataStreamAggregate { private def createKeyedWindowedStream(groupWindow: LogicalWindow, stream: KeyedStream[Row, Tuple]) : WindowedStream[Row, Tuple, _ <: DataStreamWindow] = groupWindow match { - case ProcessingTimeTumblingGroupWindow(_, size) if isTimeInterval(size.resultType) => - stream.window(TumblingProcessingTimeWindows.of(asTime(size))) + case TumblingGroupWindow(_, timeField, size) + if isProctimeAttribute(timeField) && isTimeIntervalLiteral(size)=> + stream.window(TumblingProcessingTimeWindows.of(toTime(size))) - case ProcessingTimeTumblingGroupWindow(_, size) => - stream.countWindow(asCount(size)) + case TumblingGroupWindow(_, timeField, size) + if isProctimeAttribute(timeField) && isRowCountLiteral(size)=> + stream.countWindow(toLong(size)) - case EventTimeTumblingGroupWindow(_, _, size) if isTimeInterval(size.resultType) => - stream.window(TumblingEventTimeWindows.of(asTime(size))) + case TumblingGroupWindow(_, timeField, size) + if isRowtimeAttribute(timeField) && isTimeIntervalLiteral(size) => + stream.window(TumblingEventTimeWindows.of(toTime(size))) - case EventTimeTumblingGroupWindow(_, _, size) => + case TumblingGroupWindow(_, _, size) => // TODO: EventTimeTumblingGroupWindow should sort the stream on event time // before applying the windowing logic. Otherwise, this would be the same as a // ProcessingTimeTumblingGroupWindow throw new UnsupportedOperationException( "Event-time grouping windows on row intervals are currently not supported.") - case ProcessingTimeSlidingGroupWindow(_, size, slide) if isTimeInterval(size.resultType) => - stream.window(SlidingProcessingTimeWindows.of(asTime(size), asTime(slide))) + case SlidingGroupWindow(_, timeField, size, slide) + if isProctimeAttribute(timeField) && isTimeIntervalLiteral(slide) => + stream.window(SlidingProcessingTimeWindows.of(toTime(size), toTime(slide))) - case ProcessingTimeSlidingGroupWindow(_, size, slide) => - stream.countWindow(asCount(size), asCount(slide)) + case SlidingGroupWindow(_, timeField, size, slide) + if isProctimeAttribute(timeField) && isRowCountLiteral(size) => + stream.countWindow(toLong(size), toLong(slide)) - case EventTimeSlidingGroupWindow(_, _, size, slide) if isTimeInterval(size.resultType) => - stream.window(SlidingEventTimeWindows.of(asTime(size), asTime(slide))) + case SlidingGroupWindow(_, timeField, size, slide) + if isRowtimeAttribute(timeField) && isTimeIntervalLiteral(size)=> + stream.window(SlidingEventTimeWindows.of(toTime(size), toTime(slide))) - case EventTimeSlidingGroupWindow(_, _, size, slide) => + case SlidingGroupWindow(_, _, size, slide) => // TODO: EventTimeTumblingGroupWindow should sort the stream on event time // before applying the windowing logic. Otherwise, this would be the same as a // ProcessingTimeTumblingGroupWindow throw new UnsupportedOperationException( "Event-time grouping windows on row intervals are currently not supported.") - case ProcessingTimeSessionGroupWindow(_, gap: Expression) => - stream.window(ProcessingTimeSessionWindows.withGap(asTime(gap))) + case SessionGroupWindow(_, timeField, gap) + if isProctimeAttribute(timeField) => + stream.window(ProcessingTimeSessionWindows.withGap(toTime(gap))) - case EventTimeSessionGroupWindow(_, _, gap) => - stream.window(EventTimeSessionWindows.withGap(asTime(gap))) + case SessionGroupWindow(_, timeField, gap) + if isRowtimeAttribute(timeField) => + stream.window(EventTimeSessionWindows.withGap(toTime(gap))) } private def createNonKeyedWindowedStream(groupWindow: LogicalWindow, stream: DataStream[Row]) : AllWindowedStream[Row, _ <: DataStreamWindow] = groupWindow match { - case ProcessingTimeTumblingGroupWindow(_, size) if isTimeInterval(size.resultType) => - stream.windowAll(TumblingProcessingTimeWindows.of(asTime(size))) + case TumblingGroupWindow(_, timeField, size) + if isProctimeAttribute(timeField) && isTimeIntervalLiteral(size) => + stream.windowAll(TumblingProcessingTimeWindows.of(toTime(size))) - case ProcessingTimeTumblingGroupWindow(_, size) => - stream.countWindowAll(asCount(size)) + case TumblingGroupWindow(_, timeField, size) + if isProctimeAttribute(timeField) && isRowCountLiteral(size)=> + stream.countWindowAll(toLong(size)) - case EventTimeTumblingGroupWindow(_, _, size) if isTimeInterval(size.resultType) => - stream.windowAll(TumblingEventTimeWindows.of(asTime(size))) + case TumblingGroupWindow(_, _, size) if isTimeInterval(size.resultType) => + stream.windowAll(TumblingEventTimeWindows.of(toTime(size))) - case EventTimeTumblingGroupWindow(_, _, size) => + case TumblingGroupWindow(_, _, size) => // TODO: EventTimeTumblingGroupWindow should sort the stream on event time // before applying the windowing logic. Otherwise, this would be the same as a // ProcessingTimeTumblingGroupWindow throw new UnsupportedOperationException( "Event-time grouping windows on row intervals are currently not supported.") - case ProcessingTimeSlidingGroupWindow(_, size, slide) if isTimeInterval(size.resultType) => - stream.windowAll(SlidingProcessingTimeWindows.of(asTime(size), asTime(slide))) + case SlidingGroupWindow(_, timeField, size, slide) + if isProctimeAttribute(timeField) && isTimeIntervalLiteral(size) => + stream.windowAll(SlidingProcessingTimeWindows.of(toTime(size), toTime(slide))) - case ProcessingTimeSlidingGroupWindow(_, size, slide) => - stream.countWindowAll(asCount(size), asCount(slide)) + case SlidingGroupWindow(_, timeField, size, slide) + if isProctimeAttribute(timeField) && isRowCountLiteral(size)=> + stream.countWindowAll(toLong(size), toLong(slide)) - case EventTimeSlidingGroupWindow(_, _, size, slide) if isTimeInterval(size.resultType) => - stream.windowAll(SlidingEventTimeWindows.of(asTime(size), asTime(slide))) + case SlidingGroupWindow(_, timeField, size, slide) + if isRowtimeAttribute(timeField) && isTimeIntervalLiteral(size)=> + stream.windowAll(SlidingEventTimeWindows.of(toTime(size), toTime(slide))) - case EventTimeSlidingGroupWindow(_, _, size, slide) => + case SlidingGroupWindow(_, _, size, slide) => // TODO: EventTimeTumblingGroupWindow should sort the stream on event time // before applying the windowing logic. Otherwise, this would be the same as a // ProcessingTimeTumblingGroupWindow throw new UnsupportedOperationException( "Event-time grouping windows on row intervals are currently not supported.") - case ProcessingTimeSessionGroupWindow(_, gap) => - stream.windowAll(ProcessingTimeSessionWindows.withGap(asTime(gap))) + case SessionGroupWindow(_, timeField, gap) + if isProctimeAttribute(timeField) && isTimeIntervalLiteral(gap) => + stream.windowAll(ProcessingTimeSessionWindows.withGap(toTime(gap))) - case EventTimeSessionGroupWindow(_, _, gap) => - stream.windowAll(EventTimeSessionWindows.withGap(asTime(gap))) + case SessionGroupWindow(_, timeField, gap) + if isRowtimeAttribute(timeField) && isTimeIntervalLiteral(gap) => + stream.windowAll(EventTimeSessionWindows.withGap(toTime(gap))) } - def asTime(expr: Expression): Time = expr match { - case Literal(value: Long, TimeIntervalTypeInfo.INTERVAL_MILLIS) => Time.milliseconds(value) - case _ => throw new IllegalArgumentException() - } - def asCount(expr: Expression): Long = expr match { - case Literal(value: Long, RowIntervalTypeInfo.INTERVAL_ROWS) => value - case _ => throw new IllegalArgumentException() - } } http://git-wip-us.apache.org/repos/asf/flink/blob/495f104b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCalc.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCalc.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCalc.scala index b015a1d..c6c25c0 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCalc.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCalc.scala @@ -27,9 +27,9 @@ import org.apache.calcite.rex.RexProgram import org.apache.flink.api.common.functions.FlatMapFunction import org.apache.flink.streaming.api.datastream.DataStream import org.apache.flink.table.api.StreamTableEnvironment -import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.codegen.CodeGenerator import org.apache.flink.table.plan.nodes.CommonCalc +import org.apache.flink.table.plan.schema.RowSchema import org.apache.flink.types.Row /** @@ -40,17 +40,25 @@ class DataStreamCalc( cluster: RelOptCluster, traitSet: RelTraitSet, input: RelNode, - rowRelDataType: RelDataType, + inputSchema: RowSchema, + schema: RowSchema, calcProgram: RexProgram, ruleDescription: String) extends Calc(cluster, traitSet, input, calcProgram) with CommonCalc with DataStreamRel { - override def deriveRowType(): RelDataType = rowRelDataType + override def deriveRowType(): RelDataType = schema.logicalType override def copy(traitSet: RelTraitSet, child: RelNode, program: RexProgram): Calc = { - new DataStreamCalc(cluster, traitSet, child, getRowType, program, ruleDescription) + new DataStreamCalc( + cluster, + traitSet, + child, + inputSchema, + schema, + program, + ruleDescription) } override def toString: String = calcToString(calcProgram, getExpressionString) @@ -85,8 +93,8 @@ class DataStreamCalc( val body = functionBody( generator, - inputDataStream.getType, - getRowType, + inputSchema, + schema, calcProgram, config) @@ -94,7 +102,7 @@ class DataStreamCalc( ruleDescription, classOf[FlatMapFunction[Row, Row]], body, - FlinkTypeFactory.toInternalRowTypeInfo(getRowType)) + schema.physicalTypeInfo) val mapFunc = calcMapFunction(genFunction) inputDataStream.flatMap(mapFunc).name(calcOpName(calcProgram, getExpressionString)) http://git-wip-us.apache.org/repos/asf/flink/blob/495f104b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala index 342920a..8955110 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala @@ -18,7 +18,6 @@ package org.apache.flink.table.plan.nodes.datastream import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} -import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel} import org.apache.calcite.rex.{RexCall, RexNode} import org.apache.calcite.sql.SemiJoinType @@ -28,6 +27,7 @@ import org.apache.flink.table.api.StreamTableEnvironment import org.apache.flink.table.functions.utils.TableSqlFunction import org.apache.flink.table.plan.nodes.CommonCorrelate import org.apache.flink.table.plan.nodes.logical.FlinkLogicalTableFunctionScan +import org.apache.flink.table.plan.schema.RowSchema import org.apache.flink.types.Row /** @@ -36,28 +36,30 @@ import org.apache.flink.types.Row class DataStreamCorrelate( cluster: RelOptCluster, traitSet: RelTraitSet, + inputSchema: RowSchema, inputNode: RelNode, scan: FlinkLogicalTableFunctionScan, condition: Option[RexNode], - relRowType: RelDataType, - joinRowType: RelDataType, + schema: RowSchema, + joinSchema: RowSchema, joinType: SemiJoinType, ruleDescription: String) extends SingleRel(cluster, traitSet, inputNode) with CommonCorrelate with DataStreamRel { - override def deriveRowType() = relRowType + override def deriveRowType() = schema.logicalType override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = { new DataStreamCorrelate( cluster, traitSet, + inputSchema, inputs.get(0), scan, condition, - relRowType, - joinRowType, + schema, + joinSchema, joinType, ruleDescription) } @@ -74,7 +76,7 @@ class DataStreamCorrelate( super.explainTerms(pw) .item("invocation", scan.getCall) .item("function", sqlFunction.getTableFunction.getClass.getCanonicalName) - .item("rowType", relRowType) + .item("rowType", schema.logicalType) .item("joinType", joinType) .itemIf("condition", condition.orNull, condition.isDefined) } @@ -94,16 +96,16 @@ class DataStreamCorrelate( val mapFunc = correlateMapFunction( config, - inputDS.getType, + inputSchema, udtfTypeInfo, - getRowType, + schema, joinType, rexCall, condition, Some(pojoFieldMapping), ruleDescription) - inputDS.flatMap(mapFunc).name(correlateOpName(rexCall, sqlFunction, relRowType)) + inputDS.flatMap(mapFunc).name(correlateOpName(rexCall, sqlFunction, schema.logicalType)) } }