This is an automated email from the ASF dual-hosted git repository. twalthr pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit 9f7eef293f723800945a9759c50adbf8786a2bd4 Author: slinkydeveloper <francescogu...@gmail.com> AuthorDate: Tue Nov 16 10:48:08 2021 +0100 [FLINK-24781][table-planner] Refactor cast of literals to use CastExecutor Signed-off-by: slinkydeveloper <francescogu...@gmail.com> This closes #17800. --- .../CodeGeneratedExpressionCastExecutor.java | 3 +- .../flink/table/planner/codegen/CodeGenUtils.scala | 26 ++++++- .../table/planner/codegen/GenerateUtils.scala | 16 ---- .../planner/codegen/calls/BuiltInMethods.scala | 1 - .../table/planner/codegen/calls/IfCallGen.scala | 7 +- .../planner/codegen/calls/ScalarOperatorGens.scala | 89 ++++++++++++---------- .../validation/ScalarOperatorsValidationTest.scala | 12 +-- 7 files changed, 85 insertions(+), 69 deletions(-) diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/CodeGeneratedExpressionCastExecutor.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/CodeGeneratedExpressionCastExecutor.java index 7c361ac..6e57593 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/CodeGeneratedExpressionCastExecutor.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/CodeGeneratedExpressionCastExecutor.java @@ -57,7 +57,8 @@ class CodeGeneratedExpressionCastExecutor<IN, OUT> implements CastExecutor<IN, O throw (TableException) e.getCause(); } throw new TableException( - "Cannot execute the compiled expression for an unknown cause", e); + "Cannot execute the compiled expression for an unknown cause. " + e.getCause(), + e); } } } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala index 22bb463..b21d097 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala @@ -21,7 +21,6 @@ package org.apache.flink.table.planner.codegen import java.lang.reflect.Method import java.lang.{Boolean => JBoolean, Byte => JByte, Double => JDouble, Float => JFloat, Integer => JInt, Long => JLong, Object => JObject, Short => JShort} import java.util.concurrent.atomic.AtomicLong - import org.apache.flink.api.common.ExecutionConfig import org.apache.flink.api.common.functions.RuntimeContext import org.apache.flink.core.memory.MemorySegment @@ -33,10 +32,10 @@ import org.apache.flink.table.data.util.DataFormatConverters.IdentityConverter import org.apache.flink.table.data.utils.JoinedRowData import org.apache.flink.table.functions.UserDefinedFunction import org.apache.flink.table.planner.codegen.GenerateUtils.{generateInputFieldUnboxing, generateNonNullField} +import org.apache.flink.table.planner.codegen.calls.BuiltInMethods.BINARY_STRING_DATA_FROM_STRING import org.apache.flink.table.runtime.dataview.StateDataViewStore import org.apache.flink.table.runtime.generated.{AggsHandleFunction, HashFunction, NamespaceAggsHandleFunction, TableAggsHandleFunction} import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromDataTypeToLogicalType -import org.apache.flink.table.runtime.types.PlannerTypeUtils.isInteroperable import org.apache.flink.table.runtime.typeutils.TypeCheckUtils import org.apache.flink.table.runtime.util.{MurmurHashUtil, TimeWindowUtil} import org.apache.flink.table.types.DataType @@ -46,6 +45,7 @@ import org.apache.flink.table.types.logical.utils.LogicalTypeChecks import org.apache.flink.table.types.logical.utils.LogicalTypeChecks.{getFieldCount, getPrecision, getScale} import org.apache.flink.table.types.logical.utils.LogicalTypeUtils.toInternalConversionClass import org.apache.flink.table.types.utils.DataTypeUtils.isInternal +import org.apache.flink.table.utils.EncodingUtils import org.apache.flink.types.{Row, RowKind} import scala.annotation.tailrec @@ -195,6 +195,28 @@ object CodeGenUtils { case _ => boxedTypeTermForType(t) } + /** + * Converts values to stringified representation to include in the codegen. + * + * This method doesn't support complex types. + */ + def primitiveLiteralForType(value: Any): String = value match { + // ordered by type root definition + case _: JBoolean => value.toString + case _: JByte => s"((byte)$value)" + case _: JShort => s"((short)$value)" + case _: JInt => value.toString + case _: JLong => value.toString + "L" + case _: JFloat => value.toString + "f" + case _: JDouble => value.toString + "d" + case sd: StringData => + qualifyMethod(BINARY_STRING_DATA_FROM_STRING) + "(\"" + + EncodingUtils.escapeJava(sd.toString) + "\")" + case td: TimestampData => + s"$TIMESTAMP_DATA.fromEpochMillis(${td.getMillisecond}L, ${td.getNanoOfMillisecond})" + case _ => throw new IllegalArgumentException("Illegal literal type: " + value.getClass) + } + @tailrec def boxedTypeTermForType(t: LogicalType): String = t.getTypeRoot match { // ordered by type root definition diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/GenerateUtils.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/GenerateUtils.scala index d113953..cc612ac 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/GenerateUtils.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/GenerateUtils.scala @@ -142,22 +142,6 @@ object GenerateUtils { /** - * Generates a string result call with auxiliary statements and result expression. - * This will convert the String result to BinaryStringData. - */ - def generateStringResultCallWithStmtIfArgsNotNull( - ctx: CodeGeneratorContext, - operands: Seq[GeneratedExpression], - returnType: LogicalType) - (call: Seq[String] => (String, String)): GeneratedExpression = { - generateCallWithStmtIfArgsNotNull(ctx, returnType, operands) { - args => - val (stmt, result) = call(args) - (stmt, s"$BINARY_STRING.fromString($result)") - } - } - - /** * Generates a call with the nullable args. */ def generateCallIfArgsNullable( diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BuiltInMethods.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BuiltInMethods.scala index 308826d..824f362 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BuiltInMethods.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BuiltInMethods.scala @@ -29,7 +29,6 @@ import org.apache.flink.table.data.binary.{BinaryStringData, BinaryStringDataUti import java.lang.reflect.Method import java.lang.{Byte => JByte, Integer => JInteger, Long => JLong, Short => JShort} -import java.time.ZoneId import java.util.TimeZone object BuiltInMethods { diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/IfCallGen.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/IfCallGen.scala index af8061c..5fe1dd1 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/IfCallGen.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/IfCallGen.scala @@ -19,10 +19,9 @@ package org.apache.flink.table.planner.codegen.calls import org.apache.flink.table.planner.codegen.CodeGenUtils.{className, primitiveDefaultValue, primitiveTypeTermForType} -import org.apache.flink.table.planner.codegen.calls.ScalarOperatorGens.toCastContext -import org.apache.flink.table.planner.codegen.{CodeGenException, CodeGenUtils, CodeGeneratorContext, GeneratedExpression} +import org.apache.flink.table.planner.codegen.calls.ScalarOperatorGens.toCodegenCastContext +import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, GeneratedExpression} import org.apache.flink.table.planner.functions.casting.{CastRuleProvider, ExpressionCodeGeneratorCastRule} -import org.apache.flink.table.runtime.types.PlannerTypeUtils.isInteroperable import org.apache.flink.table.types.logical.LogicalType /** @@ -86,7 +85,7 @@ class IfCallGen() extends CallGenerator { rule match { case codeGeneratorCastRule: ExpressionCodeGeneratorCastRule[_, _] => codeGeneratorCastRule.generateExpression( - toCastContext(ctx), + toCodegenCastContext(ctx), expr.resultTerm, expr.resultType, targetType diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/ScalarOperatorGens.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/ScalarOperatorGens.scala index 8554e4f..dc8bb63 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/ScalarOperatorGens.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/ScalarOperatorGens.scala @@ -20,13 +20,14 @@ package org.apache.flink.table.planner.codegen.calls import org.apache.flink.table.api.{TableException, ValidationException} import org.apache.flink.table.data.binary.BinaryArrayData -import org.apache.flink.table.planner.functions.casting.{CastRuleProvider, CodeGeneratorCastRule, ExpressionCodeGeneratorCastRule} +import org.apache.flink.table.planner.functions.casting.{CastRule, CastRuleProvider, CodeGeneratorCastRule, ExpressionCodeGeneratorCastRule} import org.apache.flink.table.data.util.MapDataUtil +import org.apache.flink.table.data.utils.CastExecutor import org.apache.flink.table.data.writer.{BinaryArrayWriter, BinaryRowWriter} import org.apache.flink.table.planner.codegen.CodeGenUtils.{binaryRowFieldSetAccess, binaryRowSetNull, binaryWriterWriteField, binaryWriterWriteNull, _} import org.apache.flink.table.planner.codegen.GenerateUtils._ import org.apache.flink.table.planner.codegen.GeneratedExpression.{ALWAYS_NULL, NEVER_NULL, NO_CODE} -import org.apache.flink.table.planner.codegen.{CodeGenException, CodeGenUtils, CodeGeneratorContext, GeneratedExpression} +import org.apache.flink.table.planner.codegen.{CodeGenException, CodeGeneratorContext, GeneratedExpression} import org.apache.flink.table.planner.utils.JavaScalaConversionUtil.toScala import org.apache.flink.table.runtime.functions.SqlFunctionUtils import org.apache.flink.table.runtime.types.PlannerTypeUtils @@ -42,6 +43,7 @@ import org.apache.flink.table.utils.DateTimeUtils import org.apache.flink.util.Preconditions.checkArgument import org.apache.flink.table.utils.DateTimeUtils.MILLIS_PER_DAY +import java.time.ZoneId import java.util.Arrays.asList import scala.collection.JavaConversions._ @@ -487,7 +489,7 @@ object ScalarOperatorGens { // for performance, we cast literal string to literal time. else if (isTimePoint(left.resultType) && isCharacterString(right.resultType)) { if (right.literal) { - generateEquals(ctx, left, generateCastStringLiteralToDateTime(ctx, right, left.resultType)) + generateEquals(ctx, left, generateCastLiteral(ctx, right, left.resultType)) } else { generateEquals(ctx, left, generateCast(ctx, right, left.resultType)) } @@ -496,7 +498,7 @@ object ScalarOperatorGens { if (left.literal) { generateEquals( ctx, - generateCastStringLiteralToDateTime(ctx, left, right.resultType), + generateCastLiteral(ctx, left, right.resultType), right) } else { generateEquals(ctx, generateCast(ctx, left, right.resultType), right) @@ -946,7 +948,7 @@ object ScalarOperatorGens { // Generate the code block val castCodeBlock = codeGeneratorCastRule.generateCodeBlock( - toCastContext(ctx), + toCodegenCastContext(ctx), operand.resultTerm, operand.nullTerm, inputType, @@ -1942,42 +1944,43 @@ object ScalarOperatorGens { } } - private def generateCastStringLiteralToDateTime( - ctx: CodeGeneratorContext, - stringLiteral: GeneratedExpression, - expectType: LogicalType): GeneratedExpression = { - checkArgument(stringLiteral.literal) - if (java.lang.Boolean.valueOf(stringLiteral.nullTerm)) { - return generateNullLiteral(expectType, nullCheck = true) + /** + * This method supports casting literals to non-composite types (primitives, strings, date time). + * Every cast result is declared as class member, in order to be able to reuse it. + */ + private def generateCastLiteral( + ctx: CodeGeneratorContext, + literalExpr: GeneratedExpression, + resultType: LogicalType): GeneratedExpression = { + checkArgument(literalExpr.literal) + if (java.lang.Boolean.valueOf(literalExpr.nullTerm)) { + return generateNullLiteral(resultType, nullCheck = true) } - val stringValue = stringLiteral.literalValue.get.toString - val literalCode = expectType.getTypeRoot match { - case DATE => - DateTimeUtils.dateStringToUnixDate(stringValue) match { - case null => throw new ValidationException(s"String '$stringValue' is not a valid date") - case v => v - } - case TIME_WITHOUT_TIME_ZONE => - DateTimeUtils.timeStringToUnixDate(stringValue) match { - case null => throw new ValidationException(s"String '$stringValue' is not a valid time") - case v => v - } - case TIMESTAMP_WITHOUT_TIME_ZONE => - DateTimeUtils.toTimestampData(stringValue) match { - case null => - throw new ValidationException(s"String '$stringValue' is not a valid timestamp") - case v => s"${CodeGenUtils.TIMESTAMP_DATA}.fromEpochMillis(" + - s"${v.getMillisecond}L, ${v.getNanoOfMillisecond})" - } - case _ => throw new UnsupportedOperationException + val castExecutor = CastRuleProvider.create( + toCastContext(ctx), + literalExpr.resultType, + resultType + ).asInstanceOf[CastExecutor[Any, Any]] + + if (castExecutor == null) { + throw new CodeGenException( + s"Unsupported casting from ${literalExpr.resultType} to $resultType") } - val typeTerm = primitiveTypeTermForType(expectType) - val resultTerm = newName("stringToTime") - val stmt = s"$typeTerm $resultTerm = $literalCode;" - ctx.addReusableMember(stmt) - GeneratedExpression(resultTerm, "false", "", expectType) + try { + val result = castExecutor.cast(literalExpr.literalValue.get) + val resultTerm = newName("stringToTime") + + val declStmt = + s"${primitiveTypeTermForType(resultType)} $resultTerm = ${primitiveLiteralForType(result)};" + + ctx.addReusableMember(declStmt) + GeneratedExpression(resultTerm, "false", "", resultType, Some(result)) + } catch { + case e: Throwable => + throw new ValidationException("Error when casting literal: " + e.getMessage, e) + } } private def generateArrayComparison( @@ -2169,7 +2172,7 @@ object ScalarOperatorGens { rule match { case codeGeneratorCastRule: ExpressionCodeGeneratorCastRule[_, _] => operandTerm => codeGeneratorCastRule.generateExpression( - toCastContext(ctx), + toCodegenCastContext(ctx), operandTerm, operandType, resultType @@ -2179,7 +2182,7 @@ object ScalarOperatorGens { } } - def toCastContext(ctx: CodeGeneratorContext): CodeGeneratorCastRule.Context = { + def toCodegenCastContext(ctx: CodeGeneratorContext): CodeGeneratorCastRule.Context = { new CodeGeneratorCastRule.Context { override def getSessionTimeZoneTerm: String = ctx.addReusableSessionTimeZone() override def declareVariable(ty: String, variablePrefix: String): String = @@ -2193,4 +2196,12 @@ object ScalarOperatorGens { } } + def toCastContext(ctx: CodeGeneratorContext): CastRule.Context = { + new CastRule.Context { + override def getSessionZoneId: ZoneId = ctx.tableConfig.getLocalTimeZone + + override def getClassLoader: ClassLoader = Thread.currentThread().getContextClassLoader + } + } + } diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/validation/ScalarOperatorsValidationTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/validation/ScalarOperatorsValidationTest.scala index 4b27008..5fc4b72 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/validation/ScalarOperatorsValidationTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/validation/ScalarOperatorsValidationTest.scala @@ -88,24 +88,24 @@ class ScalarOperatorsValidationTest extends ScalarOperatorsTestBase { @Test def testTemporalTypeEqualsInvalidStringLiteral(): Unit = { testExpectedSqlException( - "f15 = 'invalid'", "is not a valid date", + "f15 = 'invalid'", "java.time.DateTimeException", classOf[ValidationException]) testExpectedSqlException( - "'invalid' = f15", "is not a valid date", + "'invalid' = f15", "java.time.DateTimeException", classOf[ValidationException]) testExpectedSqlException( - "f21 = 'invalid'", "is not a valid time", + "f21 = 'invalid'", "java.time.DateTimeException", classOf[ValidationException]) testExpectedSqlException( - "'invalid' = f21", "is not a valid time", + "'invalid' = f21", "java.time.DateTimeException", classOf[ValidationException]) testExpectedSqlException( - "f22 = 'invalid'", "is not a valid timestamp", + "f22 = 'invalid'", "java.time.DateTimeException", classOf[ValidationException]) testExpectedSqlException( - "'invalid' = f22", "is not a valid timestamp", + "'invalid' = f22", "java.time.DateTimeException", classOf[ValidationException]) } }