This is an automated email from the ASF dual-hosted git repository. twalthr pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit 09aad58943a1597466a78ebb8d543e6baa3f5092 Author: Marios Trivyzas <mat...@gmail.com> AuthorDate: Thu Dec 9 13:52:38 2021 +0100 [FLINK-24413][table] Apply trimming & padding when CASTing to CHAR/VARCHAR Apply trimming when CASTing to `CHAR(<length>)` or `VARCHAR(<length>)` and the length of the result string exceeds the length specified. Apply padding to the right with spaces when CASTing to `CHAR(<length>)` and the result string's length is less than the specified length, so that the length of result string matches exactly the length. This closes #18063. --- .../flink/table/types/logical/VarCharType.java | 2 + .../functions/casting/ArrayToStringCastRule.java | 187 +++++++----- .../functions/casting/BinaryToStringCastRule.java | 3 +- .../functions/casting/BooleanToStringCastRule.java | 3 +- .../functions/casting/CastRulePredicate.java | 52 ++-- .../functions/casting/CastRuleProvider.java | 23 +- .../casting/CharVarCharTrimPadCastRule.java | 252 ++++++++++++++++ .../functions/casting/DateToStringCastRule.java | 7 +- .../casting/IntervalToStringCastRule.java | 3 +- .../casting/MapAndMultisetToStringCastRule.java | 300 +++++++++++-------- .../functions/casting/NumericToStringCastRule.java | 3 +- .../functions/casting/RawToStringCastRule.java | 54 +++- .../functions/casting/RowToStringCastRule.java | 78 +++-- .../functions/casting/TimeToStringCastRule.java | 3 +- .../casting/TimestampToStringCastRule.java | 3 +- .../table/planner/codegen/calls/IfCallGen.scala | 23 +- .../planner/functions/CastFunctionITCase.java | 29 +- .../functions/casting/CastRuleProviderTest.java | 19 ++ .../planner/functions/casting/CastRulesTest.java | 332 +++++++++++++++++++++ .../planner/expressions/ScalarFunctionsTest.scala | 16 +- 20 files changed, 1117 insertions(+), 275 deletions(-) diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/VarCharType.java 
b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/VarCharType.java index 5a71b21..7c73b6c 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/VarCharType.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/VarCharType.java @@ -54,6 +54,8 @@ public final class VarCharType extends LogicalType { public static final int DEFAULT_LENGTH = 1; + public static final VarCharType STRING_TYPE = new VarCharType(MAX_LENGTH); + private static final String FORMAT = "VARCHAR(%d)"; private static final String MAX_FORMAT = "STRING"; diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/ArrayToStringCastRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/ArrayToStringCastRule.java index e470739..57f9e48 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/ArrayToStringCastRule.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/ArrayToStringCastRule.java @@ -23,6 +23,7 @@ import org.apache.flink.table.types.logical.ArrayType; import org.apache.flink.table.types.logical.LogicalType; import org.apache.flink.table.types.logical.LogicalTypeFamily; import org.apache.flink.table.types.logical.LogicalTypeRoot; +import org.apache.flink.table.types.logical.utils.LogicalTypeChecks; import static org.apache.flink.table.planner.codegen.CodeGenUtils.className; import static org.apache.flink.table.planner.codegen.CodeGenUtils.newName; @@ -32,6 +33,9 @@ import static org.apache.flink.table.planner.functions.casting.CastRuleUtils.NUL import static org.apache.flink.table.planner.functions.casting.CastRuleUtils.constructorCall; import static org.apache.flink.table.planner.functions.casting.CastRuleUtils.methodCall; import static 
org.apache.flink.table.planner.functions.casting.CastRuleUtils.strLiteral; +import static org.apache.flink.table.planner.functions.casting.CharVarCharTrimPadCastRule.couldTrim; +import static org.apache.flink.table.planner.functions.casting.CharVarCharTrimPadCastRule.stringExceedsLength; +import static org.apache.flink.table.types.logical.VarCharType.STRING_TYPE; /** {@link LogicalTypeRoot#ARRAY} to {@link LogicalTypeFamily#CHARACTER_STRING} cast rule. */ class ArrayToStringCastRule extends AbstractNullAwareCodeGeneratorCastRule<ArrayData, String> { @@ -51,28 +55,54 @@ class ArrayToStringCastRule extends AbstractNullAwareCodeGeneratorCastRule<Array .build()); } - /* Example generated code for ARRAY<INT>: + /* Example generated code for ARRAY<INT> -> CHAR(10) isNull$0 = _myInputIsNull; if (!isNull$0) { builder$1.setLength(0); builder$1.append("["); - for (int i$2 = 0; i$2 < _myInput.size(); i$2++) { - if (i$2 != 0) { + for (int i$3 = 0; i$3 < _myInput.size(); i$3++) { + if (builder$1.length() > 10) { + break; + } + if (i$3 != 0) { builder$1.append(", "); } - int element$3 = -1; - boolean elementIsNull$4 = _myInput.isNullAt(i$2); - if (!elementIsNull$4) { - element$3 = _myInput.getInt(i$2); - result$2 = org.apache.flink.table.data.binary.BinaryStringData.fromString("" + element$3); - builder$1.append(result$2); + int element$4 = -1; + boolean elementIsNull$5 = _myInput.isNullAt(i$3); + if (!elementIsNull$5) { + element$4 = _myInput.getInt(i$3); + isNull$2 = false; + if (!isNull$2) { + result$3 = org.apache.flink.table.data.binary.BinaryStringData.fromString("" + element$4); + isNull$2 = result$3 == null; + } else { + result$3 = org.apache.flink.table.data.binary.BinaryStringData.EMPTY_UTF8; + } + builder$1.append(result$3); } else { builder$1.append("null"); } } builder$1.append("]"); - result$1 = org.apache.flink.table.data.binary.BinaryStringData.fromString(builder$1.toString()); + java.lang.String resultString$2; + resultString$2 = builder$1.toString(); + if 
(builder$1.length() > 10) { + resultString$2 = builder$1.substring(0, java.lang.Math.min(builder$1.length(), 10)); + } else { + if (resultString$2.length() < 10) { + int padLength$6; + padLength$6 = 10 - resultString$2.length(); + java.lang.StringBuilder sbPadding$7; + sbPadding$7 = new java.lang.StringBuilder(); + for (int i$8 = 0; i$8 < padLength$6; i$8++) { + sbPadding$7.append(" "); + } + resultString$2 = resultString$2 + sbPadding$7.toString(); + } + } + result$1 = org.apache.flink.table.data.binary.BinaryStringData.fromString(resultString$2); + isNull$0 = result$1 == null; } else { result$1 = org.apache.flink.table.data.binary.BinaryStringData.EMPTY_UTF8; } @@ -91,74 +121,93 @@ class ArrayToStringCastRule extends AbstractNullAwareCodeGeneratorCastRule<Array context.declareClassField( className(StringBuilder.class), builderTerm, constructorCall(StringBuilder.class)); - return new CastRuleUtils.CodeWriter() - .stmt(methodCall(builderTerm, "setLength", 0)) - .stmt(methodCall(builderTerm, "append", strLiteral("["))) - .forStmt( - methodCall(inputTerm, "size"), - (indexTerm, loopBodyWriter) -> { - String elementTerm = newName("element"); - String elementIsNullTerm = newName("elementIsNull"); + final String resultStringTerm = newName("resultString"); + final int length = LogicalTypeChecks.getLength(targetLogicalType); + + CastRuleUtils.CodeWriter writer = + new CastRuleUtils.CodeWriter() + .stmt(methodCall(builderTerm, "setLength", 0)) + .stmt(methodCall(builderTerm, "append", strLiteral("["))) + .forStmt( + methodCall(inputTerm, "size"), + (indexTerm, loopBodyWriter) -> { + String elementTerm = newName("element"); + String elementIsNullTerm = newName("elementIsNull"); - CastCodeBlock codeBlock = - CastRuleProvider.generateCodeBlock( - context, - elementTerm, - "false", - // Null check is done at the array access level - innerInputType.copy(false), - targetLogicalType); + CastCodeBlock codeBlock = + CastRuleProvider.generateCodeBlock( + context, + elementTerm, + 
"false", + // Null check is done at the array + // access level + innerInputType.copy(false), + STRING_TYPE); - loopBodyWriter - // Write the comma - .ifStmt( - indexTerm + " != 0", - thenBodyWriter -> - thenBodyWriter.stmt( - methodCall( - builderTerm, - "append", - strLiteral(", ")))) - // Extract element from array - .declPrimitiveStmt(innerInputType, elementTerm) - .declStmt( - boolean.class, - elementIsNullTerm, - methodCall(inputTerm, "isNullAt", indexTerm)) - .ifStmt( - "!" + elementIsNullTerm, - thenBodyWriter -> - thenBodyWriter - // If element not null, extract it and - // execute the cast - .assignStmt( - elementTerm, - rowFieldReadAccess( - indexTerm, - inputTerm, - innerInputType)) - .append(codeBlock) - .stmt( + if (!context.legacyBehaviour() && couldTrim(length)) { + // Break if the target length is already exceeded + loopBodyWriter.ifStmt( + stringExceedsLength(builderTerm, length), + thenBodyWriter -> thenBodyWriter.stmt("break")); + } + loopBodyWriter + // Write the comma + .ifStmt( + indexTerm + " != 0", + thenBodyWriter -> + thenBodyWriter.stmt( + methodCall( + builderTerm, + "append", + strLiteral(", ")))) + // Extract element from array + .declPrimitiveStmt(innerInputType, elementTerm) + .declStmt( + boolean.class, + elementIsNullTerm, + methodCall(inputTerm, "isNullAt", indexTerm)) + .ifStmt( + "!" 
+ elementIsNullTerm, + thenBodyWriter -> + thenBodyWriter + // If element not null, + // extract it and + // execute the cast + .assignStmt( + elementTerm, + rowFieldReadAccess( + indexTerm, + inputTerm, + innerInputType)) + .append(codeBlock) + .stmt( + methodCall( + builderTerm, + "append", + codeBlock + .getReturnTerm())), + elseBodyWriter -> + // If element is null, just + // write NULL + elseBodyWriter.stmt( methodCall( builderTerm, "append", - codeBlock - .getReturnTerm())), - elseBodyWriter -> - // If element is null, just write NULL - elseBodyWriter.stmt( - methodCall( - builderTerm, - "append", - NULL_STR_LITERAL))); - }) - .stmt(methodCall(builderTerm, "append", strLiteral("]"))) + NULL_STR_LITERAL))); + }) + .stmt(methodCall(builderTerm, "append", strLiteral("]"))); + return CharVarCharTrimPadCastRule.padAndTrimStringIfNeeded( + writer, + targetLogicalType, + context.legacyBehaviour(), + length, + resultStringTerm, + builderTerm) // Assign the result value .assignStmt( returnVariable, CastRuleUtils.staticCall( - BINARY_STRING_DATA_FROM_STRING(), - methodCall(builderTerm, "toString"))) + BINARY_STRING_DATA_FROM_STRING(), resultStringTerm)) .toString(); } } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/BinaryToStringCastRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/BinaryToStringCastRule.java index fd95948..126e3c0 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/BinaryToStringCastRule.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/BinaryToStringCastRule.java @@ -25,6 +25,7 @@ import java.nio.charset.StandardCharsets; import static org.apache.flink.table.planner.functions.casting.CastRuleUtils.accessStaticField; import static org.apache.flink.table.planner.functions.casting.CastRuleUtils.constructorCall; +import static 
org.apache.flink.table.types.logical.VarCharType.STRING_TYPE; /** * {@link LogicalTypeFamily#BINARY_STRING} to {@link LogicalTypeFamily#CHARACTER_STRING} cast rule. @@ -37,7 +38,7 @@ class BinaryToStringCastRule extends AbstractCharacterFamilyTargetRule<byte[]> { super( CastRulePredicate.builder() .input(LogicalTypeFamily.BINARY_STRING) - .target(LogicalTypeFamily.CHARACTER_STRING) + .target(STRING_TYPE) .build()); } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/BooleanToStringCastRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/BooleanToStringCastRule.java index ae95571..0d3ab13 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/BooleanToStringCastRule.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/BooleanToStringCastRule.java @@ -24,6 +24,7 @@ import org.apache.flink.table.types.logical.LogicalTypeRoot; import static org.apache.flink.table.planner.functions.casting.CastRuleUtils.EMPTY_STR_LITERAL; import static org.apache.flink.table.planner.functions.casting.CastRuleUtils.stringConcat; +import static org.apache.flink.table.types.logical.VarCharType.STRING_TYPE; /** {@link LogicalTypeRoot#BOOLEAN} to {@link LogicalTypeFamily#CHARACTER_STRING} cast rule. 
*/ class BooleanToStringCastRule extends AbstractCharacterFamilyTargetRule<Object> { @@ -34,7 +35,7 @@ class BooleanToStringCastRule extends AbstractCharacterFamilyTargetRule<Object> super( CastRulePredicate.builder() .input(LogicalTypeRoot.BOOLEAN) - .target(LogicalTypeFamily.CHARACTER_STRING) + .target(STRING_TYPE) .build()); } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/CastRulePredicate.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/CastRulePredicate.java index 40555a8..3b3c67f 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/CastRulePredicate.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/CastRulePredicate.java @@ -35,16 +35,17 @@ import java.util.function.BiPredicate; * of input and target type using this class. In particular, a rule is applied if: * * <ol> - * <li>{@link #getTargetTypes()} includes the {@link LogicalTypeRoot} of target type and either + * <li>{@link #getTargetTypeRoots()} includes the {@link LogicalTypeRoot} of target type and + * either * <ol> - * <li>{@link #getInputTypes()} includes the {@link LogicalTypeRoot} of input type or + * <li>{@link #getInputTypeRoots()} includes the {@link LogicalTypeRoot} of input type or * <li>{@link #getInputTypeFamilies()} includes one of the {@link LogicalTypeFamily} of * input type * </ol> * <li>Or {@link #getTargetTypeFamilies()} includes one of the {@link LogicalTypeFamily} of target * type and either * <ol> - * <li>{@link #getInputTypes()} includes the {@link LogicalTypeRoot} of input type or + * <li>{@link #getInputTypeRoots()} includes the {@link LogicalTypeRoot} of input type or * <li>{@link #getInputTypeFamilies()} includes one of the {@link LogicalTypeFamily} of * input type * </ol> @@ -59,8 +60,10 @@ import java.util.function.BiPredicate; @Internal public class 
CastRulePredicate { - private final Set<LogicalTypeRoot> inputTypes; - private final Set<LogicalTypeRoot> targetTypes; + private final Set<LogicalType> targetTypes; + + private final Set<LogicalTypeRoot> inputTypeRoots; + private final Set<LogicalTypeRoot> targetTypeRoots; private final Set<LogicalTypeFamily> inputTypeFamilies; private final Set<LogicalTypeFamily> targetTypeFamilies; @@ -68,24 +71,30 @@ public class CastRulePredicate { private final BiPredicate<LogicalType, LogicalType> customPredicate; private CastRulePredicate( - Set<LogicalTypeRoot> inputTypes, - Set<LogicalTypeRoot> targetTypes, + Set<LogicalType> targetTypes, + Set<LogicalTypeRoot> inputTypeRoots, + Set<LogicalTypeRoot> targetTypeRoots, Set<LogicalTypeFamily> inputTypeFamilies, Set<LogicalTypeFamily> targetTypeFamilies, BiPredicate<LogicalType, LogicalType> customPredicate) { - this.inputTypes = inputTypes; this.targetTypes = targetTypes; + this.inputTypeRoots = inputTypeRoots; + this.targetTypeRoots = targetTypeRoots; this.inputTypeFamilies = inputTypeFamilies; this.targetTypeFamilies = targetTypeFamilies; this.customPredicate = customPredicate; } - public Set<LogicalTypeRoot> getInputTypes() { - return inputTypes; + public Set<LogicalType> getTargetTypes() { + return targetTypes; } - public Set<LogicalTypeRoot> getTargetTypes() { - return targetTypes; + public Set<LogicalTypeRoot> getInputTypeRoots() { + return inputTypeRoots; + } + + public Set<LogicalTypeRoot> getTargetTypeRoots() { + return targetTypeRoots; } public Set<LogicalTypeFamily> getInputTypeFamilies() { @@ -106,20 +115,26 @@ public class CastRulePredicate { /** Builder for the {@link CastRulePredicate}. 
*/ public static class Builder { - private final Set<LogicalTypeRoot> inputTypes = new HashSet<>(); - private final Set<LogicalTypeRoot> targetTypes = new HashSet<>(); + private final Set<LogicalTypeRoot> inputTypeRoots = new HashSet<>(); + private final Set<LogicalTypeRoot> targetTypeRoots = new HashSet<>(); + private final Set<LogicalType> targetTypes = new HashSet<>(); private final Set<LogicalTypeFamily> inputTypeFamilies = new HashSet<>(); private final Set<LogicalTypeFamily> targetTypeFamilies = new HashSet<>(); private BiPredicate<LogicalType, LogicalType> customPredicate; - public Builder input(LogicalTypeRoot inputType) { - inputTypes.add(inputType); + public Builder input(LogicalTypeRoot inputTypeRoot) { + inputTypeRoots.add(inputTypeRoot); + return this; + } + + public Builder target(LogicalTypeRoot outputTypeRoot) { + targetTypeRoots.add(outputTypeRoot); return this; } - public Builder target(LogicalTypeRoot outputType) { + public Builder target(LogicalType outputType) { targetTypes.add(outputType); return this; } @@ -141,8 +156,9 @@ public class CastRulePredicate { public CastRulePredicate build() { return new CastRulePredicate( - Collections.unmodifiableSet(inputTypes), Collections.unmodifiableSet(targetTypes), + Collections.unmodifiableSet(inputTypeRoots), + Collections.unmodifiableSet(targetTypeRoots), Collections.unmodifiableSet(inputTypeFamilies), Collections.unmodifiableSet(targetTypeFamilies), customPredicate); diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/CastRuleProvider.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/CastRuleProvider.java index afb91b7..b2d5fd1 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/CastRuleProvider.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/CastRuleProvider.java @@ -78,6 +78,7 @@ 
public class CastRuleProvider { .addRule(ArrayToArrayCastRule.INSTANCE) .addRule(RowToRowCastRule.INSTANCE) // Special rules + .addRule(CharVarCharTrimPadCastRule.INSTANCE) .addRule(NullToStringCastRule.INSTANCE) .addRule(IdentityCastRule.INSTANCE); } @@ -148,10 +149,20 @@ public class CastRuleProvider { private CastRuleProvider addRule(CastRule<?, ?> rule) { CastRulePredicate predicate = rule.getPredicateDefinition(); - for (LogicalTypeRoot targetTypeRoot : predicate.getTargetTypes()) { + for (LogicalType targetType : predicate.getTargetTypes()) { + final Map<Object, CastRule<?, ?>> map = + rules.computeIfAbsent(targetType, k -> new HashMap<>()); + for (LogicalTypeRoot inputTypeRoot : predicate.getInputTypeRoots()) { + map.put(inputTypeRoot, rule); + } + for (LogicalTypeFamily inputTypeFamily : predicate.getInputTypeFamilies()) { + map.put(inputTypeFamily, rule); + } + } + for (LogicalTypeRoot targetTypeRoot : predicate.getTargetTypeRoots()) { final Map<Object, CastRule<?, ?>> map = rules.computeIfAbsent(targetTypeRoot, k -> new HashMap<>()); - for (LogicalTypeRoot inputTypeRoot : predicate.getInputTypes()) { + for (LogicalTypeRoot inputTypeRoot : predicate.getInputTypeRoots()) { map.put(inputTypeRoot, rule); } for (LogicalTypeFamily inputTypeFamily : predicate.getInputTypeFamilies()) { @@ -161,7 +172,7 @@ public class CastRuleProvider { for (LogicalTypeFamily targetTypeFamily : predicate.getTargetTypeFamilies()) { final Map<Object, CastRule<?, ?>> map = rules.computeIfAbsent(targetTypeFamily, k -> new HashMap<>()); - for (LogicalTypeRoot inputTypeRoot : predicate.getInputTypes()) { + for (LogicalTypeRoot inputTypeRoot : predicate.getInputTypeRoots()) { map.put(inputTypeRoot, rule); } for (LogicalTypeFamily inputTypeFamily : predicate.getInputTypeFamilies()) { @@ -182,8 +193,10 @@ public class CastRuleProvider { final Iterator<Object> targetTypeRootFamilyIterator = Stream.<Object>concat( - Stream.of(targetType.getTypeRoot()), - 
targetType.getTypeRoot().getFamilies().stream()) + Stream.of(targetType), + Stream.<Object>concat( + Stream.of(targetType.getTypeRoot()), + targetType.getTypeRoot().getFamilies().stream())) .iterator(); // Try lookup by target type root/type families diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/CharVarCharTrimPadCastRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/CharVarCharTrimPadCastRule.java new file mode 100644 index 0000000..9d87074 --- /dev/null +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/CharVarCharTrimPadCastRule.java @@ -0,0 +1,252 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.planner.functions.casting; + +import org.apache.flink.table.data.StringData; +import org.apache.flink.table.data.binary.BinaryStringData; +import org.apache.flink.table.data.binary.BinaryStringDataUtil; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.LogicalTypeFamily; +import org.apache.flink.table.types.logical.LogicalTypeRoot; +import org.apache.flink.table.types.logical.VarCharType; +import org.apache.flink.table.types.logical.utils.LogicalTypeChecks; + +import static org.apache.flink.table.planner.codegen.CodeGenUtils.newName; +import static org.apache.flink.table.planner.functions.casting.CastRuleUtils.constructorCall; +import static org.apache.flink.table.planner.functions.casting.CastRuleUtils.methodCall; +import static org.apache.flink.table.planner.functions.casting.CastRuleUtils.staticCall; +import static org.apache.flink.table.types.logical.VarCharType.STRING_TYPE; + +/** + * Any source type to {@link LogicalTypeFamily#CHARACTER_STRING} cast rule. + * + * <p>This rule is used for casting from any of the {@link LogicalTypeFamily#PREDEFINED} types to + * {@link LogicalTypeRoot#CHAR} or {@link LogicalTypeRoot#VARCHAR}. It calls the underlying concrete + * matching rule, i.e.: {@link NumericToStringCastRule} to do the actual conversion and then + * performs any necessary trimming or padding so that the length of the result string value matches + * the one specified by the length of the target {@link LogicalTypeRoot#CHAR} or {@link
+ */ +class CharVarCharTrimPadCastRule + extends AbstractNullAwareCodeGeneratorCastRule<Object, StringData> { + + static final CharVarCharTrimPadCastRule INSTANCE = new CharVarCharTrimPadCastRule(); + + private CharVarCharTrimPadCastRule() { + super( + CastRulePredicate.builder() + .predicate( + (inputType, targetType) -> + targetType.is(LogicalTypeFamily.CHARACTER_STRING) + && !targetType.equals(STRING_TYPE)) + .build()); + } + + /* Example generated code for STRING() -> CHAR(6) cast + + isNull$0 = _myInputIsNull; + if (!isNull$0) { + if (_myInput.numChars() > 6) { + result$1 = _myInput.substring(0, 6); + } else { + if (_myInput.numChars() < 6) { + int padLength$1; + padLength$1 = 6 - _myInput.numChars(); + org.apache.flink.table.data.binary.BinaryStringData padString$2; + padString$2 = org.apache.flink.table.data.binary.BinaryStringData.blankString(padLength$1); + result$1 = org.apache.flink.table.data.binary.BinaryStringDataUtil.concat(_myInput, padString$2); + } else { + result$1 = _myInput; + } + } + isNull$0 = result$1 == null; + } else { + result$1 = org.apache.flink.table.data.binary.BinaryStringData.EMPTY_UTF8; + } + + */ + @Override + protected String generateCodeBlockInternal( + CodeGeneratorCastRule.Context context, + String inputTerm, + String returnVariable, + LogicalType inputLogicalType, + LogicalType targetLogicalType) { + final int length = LogicalTypeChecks.getLength(targetLogicalType); + CastRule<?, ?> castRule = + CastRuleProvider.resolve(inputLogicalType, VarCharType.STRING_TYPE); + + // Only used for non-Constructed types - for constructed type and RAW, the trimming/padding + // is applied on each individual rule, i.e.: ArrayToStringCastRule, RawToStringCastRule + if (castRule instanceof ExpressionCodeGeneratorCastRule) { + @SuppressWarnings("rawtypes") + final String stringExpr = + ((ExpressionCodeGeneratorCastRule) castRule) + .generateExpression( + context, inputTerm, inputLogicalType, targetLogicalType); + + final CastRuleUtils.CodeWriter 
writer = new CastRuleUtils.CodeWriter(); + if (context.legacyBehaviour() + || !(couldTrim(length) || couldPad(targetLogicalType, length))) { + return writer.assignStmt(returnVariable, stringExpr).toString(); + } + return writer.ifStmt( + methodCall(stringExpr, "numChars") + " > " + length, + thenWriter -> + thenWriter.assignStmt( + returnVariable, + methodCall(stringExpr, "substring", 0, length)), + elseWriter -> { + if (couldPad(targetLogicalType, length)) { + final String padLength = newName("padLength"); + final String padString = newName("padString"); + elseWriter.ifStmt( + methodCall(stringExpr, "numChars") + " < " + length, + thenInnerWriter -> + thenInnerWriter + .declStmt(int.class, padLength) + .assignStmt( + padLength, + length + + " - " + + methodCall( + stringExpr, + "numChars")) + .declStmt( + BinaryStringData.class, + padString) + .assignStmt( + padString, + staticCall( + BinaryStringData.class, + "blankString", + padLength)) + .assignStmt( + returnVariable, + staticCall( + BinaryStringDataUtil + .class, + "concat", + stringExpr, + padString)), + elseInnerWriter -> + elseInnerWriter.assignStmt( + returnVariable, stringExpr)); + } else { + elseWriter.assignStmt(returnVariable, stringExpr); + } + }) + .toString(); + } else { + throw new IllegalStateException("This is a bug. 
Please file an issue."); + } + } + + // --------------- + // Shared methods + // --------------- + + static String stringExceedsLength(String strTerm, int targetLength) { + return methodCall(strTerm, "length") + " > " + targetLength; + } + + static String stringShouldPad(String strTerm, int targetLength) { + return methodCall(strTerm, "length") + " < " + targetLength; + } + + static boolean couldTrim(int targetLength) { + return targetLength < VarCharType.MAX_LENGTH; + } + + static boolean couldPad(LogicalType targetType, int targetLength) { + return targetType.is(LogicalTypeRoot.CHAR) && targetLength < VarCharType.MAX_LENGTH; + } + + static CastRuleUtils.CodeWriter padAndTrimStringIfNeeded( + CastRuleUtils.CodeWriter writer, + LogicalType targetType, + boolean legacyBehaviour, + int length, + String resultStringTerm, + String builderTerm) { + writer.declStmt(String.class, resultStringTerm) + .assignStmt(resultStringTerm, methodCall(builderTerm, "toString")); + + // Trim and Pad if needed + if (!legacyBehaviour && (couldTrim(length) || couldPad(targetType, length))) { + writer.ifStmt( + stringExceedsLength(builderTerm, length), + thenWriter -> + thenWriter.assignStmt( + resultStringTerm, + methodCall( + builderTerm, + "substring", + 0, + staticCall( + Math.class, + "min", + methodCall(builderTerm, "length"), + length))), + elseWriter -> + padStringIfNeeded( + elseWriter, + targetType, + legacyBehaviour, + length, + resultStringTerm)); + } + return writer; + } + + static void padStringIfNeeded( + CastRuleUtils.CodeWriter writer, + LogicalType targetType, + boolean legacyBehaviour, + int length, + String returnTerm) { + + // Pad if needed + if (!legacyBehaviour && couldPad(targetType, length)) { + final String padLength = newName("padLength"); + final String sbPadding = newName("sbPadding"); + writer.ifStmt( + stringShouldPad(returnTerm, length), + thenWriter -> + thenWriter + .declStmt(int.class, padLength) + .assignStmt( + padLength, + length + " - " + 
methodCall(returnTerm, "length")) + .declStmt(StringBuilder.class, sbPadding) + .assignStmt(sbPadding, constructorCall(StringBuilder.class)) + .forStmt( + padLength, + (idx, loopWriter) -> + loopWriter.stmt( + methodCall( + sbPadding, "append", "\" \""))) + .assignStmt( + returnTerm, + returnTerm + + " + " + + methodCall(sbPadding, "toString"))); + } + } +} diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/DateToStringCastRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/DateToStringCastRule.java index d4ab3b8..949b1af 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/DateToStringCastRule.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/DateToStringCastRule.java @@ -24,6 +24,7 @@ import org.apache.flink.table.types.logical.LogicalTypeRoot; import static org.apache.flink.table.planner.codegen.calls.BuiltInMethods.UNIX_DATE_TO_STRING; import static org.apache.flink.table.planner.functions.casting.CastRuleUtils.staticCall; +import static org.apache.flink.table.types.logical.VarCharType.STRING_TYPE; /** {@link LogicalTypeRoot#DATE} to {@link LogicalTypeFamily#CHARACTER_STRING} cast rule. 
*/ class DateToStringCastRule extends AbstractCharacterFamilyTargetRule<Long> { @@ -31,11 +32,7 @@ class DateToStringCastRule extends AbstractCharacterFamilyTargetRule<Long> { static final DateToStringCastRule INSTANCE = new DateToStringCastRule(); private DateToStringCastRule() { - super( - CastRulePredicate.builder() - .input(LogicalTypeRoot.DATE) - .target(LogicalTypeFamily.CHARACTER_STRING) - .build()); + super(CastRulePredicate.builder().input(LogicalTypeRoot.DATE).target(STRING_TYPE).build()); } @Override diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/IntervalToStringCastRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/IntervalToStringCastRule.java index 8773a93..ba16178 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/IntervalToStringCastRule.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/IntervalToStringCastRule.java @@ -27,6 +27,7 @@ import java.lang.reflect.Method; import static org.apache.flink.table.planner.codegen.calls.BuiltInMethods.INTERVAL_DAY_TIME_TO_STRING; import static org.apache.flink.table.planner.codegen.calls.BuiltInMethods.INTERVAL_YEAR_MONTH_TO_STRING; import static org.apache.flink.table.planner.functions.casting.CastRuleUtils.staticCall; +import static org.apache.flink.table.types.logical.VarCharType.STRING_TYPE; /** {@link LogicalTypeFamily#INTERVAL} to {@link LogicalTypeFamily#CHARACTER_STRING} cast rule. 
*/ class IntervalToStringCastRule extends AbstractCharacterFamilyTargetRule<Object> { @@ -37,7 +38,7 @@ class IntervalToStringCastRule extends AbstractCharacterFamilyTargetRule<Object> super( CastRulePredicate.builder() .input(LogicalTypeFamily.INTERVAL) - .target(LogicalTypeFamily.CHARACTER_STRING) + .target(STRING_TYPE) .build()); } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/MapAndMultisetToStringCastRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/MapAndMultisetToStringCastRule.java index 25a4478..dd03c51 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/MapAndMultisetToStringCastRule.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/MapAndMultisetToStringCastRule.java @@ -25,6 +25,7 @@ import org.apache.flink.table.types.logical.LogicalTypeFamily; import org.apache.flink.table.types.logical.LogicalTypeRoot; import org.apache.flink.table.types.logical.MapType; import org.apache.flink.table.types.logical.MultisetType; +import org.apache.flink.table.types.logical.utils.LogicalTypeChecks; import java.util.function.Consumer; @@ -36,6 +37,9 @@ import static org.apache.flink.table.planner.functions.casting.CastRuleUtils.NUL import static org.apache.flink.table.planner.functions.casting.CastRuleUtils.constructorCall; import static org.apache.flink.table.planner.functions.casting.CastRuleUtils.methodCall; import static org.apache.flink.table.planner.functions.casting.CastRuleUtils.strLiteral; +import static org.apache.flink.table.planner.functions.casting.CharVarCharTrimPadCastRule.couldTrim; +import static org.apache.flink.table.planner.functions.casting.CharVarCharTrimPadCastRule.stringExceedsLength; +import static org.apache.flink.table.types.logical.VarCharType.STRING_TYPE; /** * {@link LogicalTypeRoot#MAP} and {@link 
LogicalTypeRoot#MULTISET} to {@link @@ -64,7 +68,7 @@ class MapAndMultisetToStringCastRule ((MultisetType) input).getElementType(), target))); } - /* Example generated code for MAP<STRING, INTERVAL MONTH>: + /* Example generated code for MAP<STRING, INTERVAL MONTH> -> CHAR(12): isNull$0 = _myInputIsNull; if (!isNull$0) { @@ -72,31 +76,57 @@ class MapAndMultisetToStringCastRule org.apache.flink.table.data.ArrayData values$3 = _myInput.valueArray(); builder$1.setLength(0); builder$1.append("{"); - for (int i$4 = 0; i$4 < _myInput.size(); i$4++) { - if (i$4 != 0) { + for (int i$5 = 0; i$5 < _myInput.size(); i$5++) { + if (builder$1.length() > 12) { + break; + } + if (i$5 != 0) { builder$1.append(", "); } - org.apache.flink.table.data.binary.BinaryStringData key$5 = org.apache.flink.table.data.binary.BinaryStringData.EMPTY_UTF8; - boolean keyIsNull$6 = keys$2.isNullAt(i$4); - int value$7 = -1; - boolean valueIsNull$8 = values$3.isNullAt(i$4); - if (!keyIsNull$6) { - key$5 = ((org.apache.flink.table.data.binary.BinaryStringData) keys$2.getString(i$4)); - builder$1.append(key$5); + org.apache.flink.table.data.binary.BinaryStringData key$6 = org.apache.flink.table.data.binary.BinaryStringData.EMPTY_UTF8; + boolean keyIsNull$7 = keys$2.isNullAt(i$5); + int value$8 = -1; + boolean valueIsNull$9 = values$3.isNullAt(i$5); + if (!keyIsNull$7) { + key$6 = ((org.apache.flink.table.data.binary.BinaryStringData) keys$2.getString(i$5)); + builder$1.append(key$6); } else { builder$1.append("null"); } builder$1.append("="); - if (!valueIsNull$8) { - value$7 = values$3.getInt(i$4); - result$2 = org.apache.flink.table.data.binary.BinaryStringData.fromString(org.apache.flink.table.utils.DateTimeUtils.intervalYearMonthToString(value$7)); - builder$1.append(result$2); + if (!valueIsNull$9) { + value$8 = values$3.getInt(i$5); + isNull$2 = valueIsNull$9; + if (!isNull$2) { + result$3 = org.apache.flink.table.data.binary.BinaryStringData.fromString("" + value$8); + isNull$2 = result$3 == 
null; + } else { + result$3 = org.apache.flink.table.data.binary.BinaryStringData.EMPTY_UTF8; + } + builder$1.append(result$3); } else { builder$1.append("null"); } } builder$1.append("}"); - result$1 = org.apache.flink.table.data.binary.BinaryStringData.fromString(builder$1.toString()); + java.lang.String resultString$4; + resultString$4 = builder$1.toString(); + if (builder$1.length() > 12) { + resultString$4 = builder$1.substring(0, java.lang.Math.min(builder$1.length(), 12)); + } else { + if (resultString$.length() < 12) { + int padLength$10; + padLength$10 = 12 - resultString$.length(); + java.lang.StringBuilder sbPadding$11; + sbPadding$11 = new java.lang.StringBuilder(); + for (int i$12 = 0; i$12 < padLength$10; i$12++) { + sbPadding$11.append(" "); + } + resultString$4 = resultString$4 + sbPadding$11.toString(); + } + } + result$1 = org.apache.flink.table.data.binary.BinaryStringData.fromString(resultString$4); + isNull$0 = result$1 == null; } else { result$1 = org.apache.flink.table.data.binary.BinaryStringData.EMPTY_UTF8; } @@ -125,122 +155,154 @@ class MapAndMultisetToStringCastRule final String keyArrayTerm = newName("keys"); final String valueArrayTerm = newName("values"); - return new CastRuleUtils.CodeWriter() - .declStmt(ArrayData.class, keyArrayTerm, methodCall(inputTerm, "keyArray")) - .declStmt(ArrayData.class, valueArrayTerm, methodCall(inputTerm, "valueArray")) - .stmt(methodCall(builderTerm, "setLength", 0)) - .stmt(methodCall(builderTerm, "append", strLiteral("{"))) - .forStmt( - methodCall(inputTerm, "size"), - (indexTerm, loopBodyWriter) -> { - String keyTerm = newName("key"); - String keyIsNullTerm = newName("keyIsNull"); - String valueTerm = newName("value"); - String valueIsNullTerm = newName("valueIsNull"); - - CastCodeBlock keyCast = - CastRuleProvider.generateCodeBlock( - context, - keyTerm, - keyIsNullTerm, - // Null check is done at the key array access level - keyType.copy(false), - targetLogicalType); - CastCodeBlock valueCast = - 
CastRuleProvider.generateCodeBlock( - context, - valueTerm, - valueIsNullTerm, - // Null check is done at the value array access level - valueType.copy(false), - targetLogicalType); - - Consumer<CastRuleUtils.CodeWriter> appendNonNullValue = - bodyWriter -> - bodyWriter - // If value not null, extract it and - // execute the cast - .assignStmt( - valueTerm, - rowFieldReadAccess( - indexTerm, - valueArrayTerm, - valueType)) - .append(valueCast) - .stmt( - methodCall( - builderTerm, - "append", - valueCast.getReturnTerm())); - loopBodyWriter - // Write the comma - .ifStmt( - indexTerm + " != 0", - thenBodyWriter -> - thenBodyWriter.stmt( - methodCall( - builderTerm, - "append", - strLiteral(", ")))) - // Declare key and values variables - .declPrimitiveStmt(keyType, keyTerm) - .declStmt( - boolean.class, - keyIsNullTerm, - methodCall(keyArrayTerm, "isNullAt", indexTerm)) - .declPrimitiveStmt(valueType, valueTerm) - .declStmt( - boolean.class, - valueIsNullTerm, - methodCall(valueArrayTerm, "isNullAt", indexTerm)) - // Execute casting if inner key/value not null - .ifStmt( - "!" 
+ keyIsNullTerm, - thenBodyWriter -> - thenBodyWriter - // If key not null, extract it and + final String resultStringTerm = newName("resultString"); + final int length = LogicalTypeChecks.getLength(targetLogicalType); + + CastRuleUtils.CodeWriter writer = + new CastRuleUtils.CodeWriter() + .declStmt(ArrayData.class, keyArrayTerm, methodCall(inputTerm, "keyArray")) + .declStmt( + ArrayData.class, + valueArrayTerm, + methodCall(inputTerm, "valueArray")) + .stmt(methodCall(builderTerm, "setLength", 0)) + .stmt(methodCall(builderTerm, "append", strLiteral("{"))) + .forStmt( + methodCall(inputTerm, "size"), + (indexTerm, loopBodyWriter) -> { + String keyTerm = newName("key"); + String keyIsNullTerm = newName("keyIsNull"); + String valueTerm = newName("value"); + String valueIsNullTerm = newName("valueIsNull"); + + CastCodeBlock keyCast = + CastRuleProvider.generateCodeBlock( + context, + keyTerm, + keyIsNullTerm, + // Null check is done at the key array + // access level + keyType.copy(false), + STRING_TYPE); + CastCodeBlock valueCast = + CastRuleProvider.generateCodeBlock( + context, + valueTerm, + valueIsNullTerm, + // Null check is done at the value array + // access level + valueType.copy(false), + STRING_TYPE); + + Consumer<CastRuleUtils.CodeWriter> appendNonNullValue = + bodyWriter -> + bodyWriter + // If value not null, extract it + // and // execute the cast .assignStmt( - keyTerm, + valueTerm, rowFieldReadAccess( indexTerm, - keyArrayTerm, - keyType)) - .append(keyCast) + valueArrayTerm, + valueType)) + .append(valueCast) .stmt( methodCall( builderTerm, "append", - keyCast - .getReturnTerm())), - elseBodyWriter -> - elseBodyWriter.stmt( - methodCall( - builderTerm, - "append", - NULL_STR_LITERAL))) - .stmt(methodCall(builderTerm, "append", strLiteral("="))); - if (inputLogicalType.is(LogicalTypeRoot.MULTISET)) { - appendNonNullValue.accept(loopBodyWriter); - } else { - loopBodyWriter.ifStmt( - "!" 
+ valueIsNullTerm, - appendNonNullValue, - elseBodyWriter -> - elseBodyWriter.stmt( - methodCall( - builderTerm, - "append", - NULL_STR_LITERAL))); - } - }) - .stmt(methodCall(builderTerm, "append", strLiteral("}"))) + valueCast + .getReturnTerm())); + if (!context.legacyBehaviour() && couldTrim(length)) { + loopBodyWriter + // Break if the target length is already + // exceeded + .ifStmt( + stringExceedsLength(builderTerm, length), + thenBodyWriter -> thenBodyWriter.stmt("break")); + } + loopBodyWriter + // Write the comma + .ifStmt( + indexTerm + " != 0", + thenBodyWriter -> + thenBodyWriter.stmt( + methodCall( + builderTerm, + "append", + strLiteral(", ")))) + // Declare key and values variables + .declPrimitiveStmt(keyType, keyTerm) + .declStmt( + boolean.class, + keyIsNullTerm, + methodCall(keyArrayTerm, "isNullAt", indexTerm)) + .declPrimitiveStmt(valueType, valueTerm) + .declStmt( + boolean.class, + valueIsNullTerm, + methodCall( + valueArrayTerm, "isNullAt", indexTerm)) + // Execute casting if inner key/value not null + .ifStmt( + "!" + keyIsNullTerm, + thenBodyWriter -> + thenBodyWriter + // If key not null, + // extract it and + // execute the cast + .assignStmt( + keyTerm, + rowFieldReadAccess( + indexTerm, + keyArrayTerm, + keyType)) + .append(keyCast) + .stmt( + methodCall( + builderTerm, + "append", + keyCast + .getReturnTerm())), + elseBodyWriter -> + elseBodyWriter.stmt( + methodCall( + builderTerm, + "append", + NULL_STR_LITERAL))) + .stmt( + methodCall( + builderTerm, + "append", + strLiteral("="))); + if (inputLogicalType.is(LogicalTypeRoot.MULTISET)) { + appendNonNullValue.accept(loopBodyWriter); + } else { + loopBodyWriter.ifStmt( + "!" 
+ valueIsNullTerm, + appendNonNullValue, + elseBodyWriter -> + elseBodyWriter.stmt( + methodCall( + builderTerm, + "append", + NULL_STR_LITERAL))); + } + }) + .stmt(methodCall(builderTerm, "append", strLiteral("}"))); + + return CharVarCharTrimPadCastRule.padAndTrimStringIfNeeded( + writer, + targetLogicalType, + context.legacyBehaviour(), + length, + resultStringTerm, + builderTerm) // Assign the result value .assignStmt( returnVariable, CastRuleUtils.staticCall( - BINARY_STRING_DATA_FROM_STRING(), - methodCall(builderTerm, "toString"))) + BINARY_STRING_DATA_FROM_STRING(), resultStringTerm)) .toString(); } } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/NumericToStringCastRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/NumericToStringCastRule.java index d83741e..c6a03c2 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/NumericToStringCastRule.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/NumericToStringCastRule.java @@ -23,6 +23,7 @@ import org.apache.flink.table.types.logical.LogicalTypeFamily; import static org.apache.flink.table.planner.functions.casting.CastRuleUtils.EMPTY_STR_LITERAL; import static org.apache.flink.table.planner.functions.casting.CastRuleUtils.stringConcat; +import static org.apache.flink.table.types.logical.VarCharType.STRING_TYPE; /** {@link LogicalTypeFamily#NUMERIC} to {@link LogicalTypeFamily#CHARACTER_STRING} cast rule. 
*/ class NumericToStringCastRule extends AbstractCharacterFamilyTargetRule<Object> { @@ -33,7 +34,7 @@ class NumericToStringCastRule extends AbstractCharacterFamilyTargetRule<Object> super( CastRulePredicate.builder() .input(LogicalTypeFamily.NUMERIC) - .target(LogicalTypeFamily.CHARACTER_STRING) + .target(STRING_TYPE) .build()); } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/RawToStringCastRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/RawToStringCastRule.java index 3301f5f..454a7ab 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/RawToStringCastRule.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/RawToStringCastRule.java @@ -18,9 +18,11 @@ package org.apache.flink.table.planner.functions.casting; +import org.apache.flink.table.planner.codegen.CodeGenUtils; import org.apache.flink.table.types.logical.LogicalType; import org.apache.flink.table.types.logical.LogicalTypeFamily; import org.apache.flink.table.types.logical.LogicalTypeRoot; +import org.apache.flink.table.types.logical.utils.LogicalTypeChecks; import static org.apache.flink.table.codesplit.CodeSplitUtil.newName; import static org.apache.flink.table.planner.codegen.calls.BuiltInMethods.BINARY_STRING_DATA_FROM_STRING; @@ -39,6 +41,38 @@ class RawToStringCastRule extends AbstractNullAwareCodeGeneratorCastRule<Object, .build()); } + /* Example RAW(LocalDateTime.class) -> CHAR(12) + + isNull$0 = _myInputIsNull; + if (!isNull$0) { + java.lang.Object deserializedObj$0 = _myInput.toObject(typeSerializer$2); + if (deserializedObj$0 != null) { + java.lang.String resultString$1; + resultString$1 = deserializedObj$0.toString().toString(); + if (deserializedObj$0.toString().length() > 12) { + resultString$1 = deserializedObj$0.toString().substring(0, 
java.lang.Math.min(deserializedObj$0.toString().length(), 12)); + } else { + if (resultString$1.length() < 12) { + int padLength$2; + padLength$2 = 12 - resultString$1.length(); + java.lang.StringBuilder sbPadding$3; + sbPadding$3 = new java.lang.StringBuilder(); + for (int i$4 = 0; i$4 < padLength$2; i$4++) { + sbPadding$3.append(" "); + } + resultString$1 = resultString$1 + sbPadding$3.toString(); + } + } + result$1 = org.apache.flink.table.data.binary.BinaryStringData.fromString(resultString$1); + } else { + result$1 = null; + } + isNull$0 = result$1 == null; + } else { + result$1 = org.apache.flink.table.data.binary.BinaryStringData.EMPTY_UTF8; + } + + */ @Override protected String generateCodeBlockInternal( CodeGeneratorCastRule.Context context, @@ -49,6 +83,9 @@ class RawToStringCastRule extends AbstractNullAwareCodeGeneratorCastRule<Object, final String typeSerializer = context.declareTypeSerializer(inputLogicalType); final String deserializedObjTerm = newName("deserializedObj"); + final String resultStringTerm = CodeGenUtils.newName("resultString"); + final int length = LogicalTypeChecks.getLength(targetLogicalType); + return new CastRuleUtils.CodeWriter() .declStmt( Object.class, @@ -57,11 +94,18 @@ class RawToStringCastRule extends AbstractNullAwareCodeGeneratorCastRule<Object, .ifStmt( deserializedObjTerm + " != null", thenWriter -> - thenWriter.assignStmt( - returnVariable, - CastRuleUtils.staticCall( - BINARY_STRING_DATA_FROM_STRING(), - methodCall(deserializedObjTerm, "toString"))), + CharVarCharTrimPadCastRule.padAndTrimStringIfNeeded( + thenWriter, + targetLogicalType, + context.legacyBehaviour(), + length, + resultStringTerm, + methodCall(deserializedObjTerm, "toString")) + .assignStmt( + returnVariable, + CastRuleUtils.staticCall( + BINARY_STRING_DATA_FROM_STRING(), + resultStringTerm)), elseWriter -> elseWriter.assignStmt(returnVariable, "null")) .toString(); } diff --git 
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/RowToStringCastRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/RowToStringCastRule.java index 3c38f5d..d206618 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/RowToStringCastRule.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/RowToStringCastRule.java @@ -22,6 +22,7 @@ import org.apache.flink.table.data.ArrayData; import org.apache.flink.table.types.logical.LogicalType; import org.apache.flink.table.types.logical.LogicalTypeFamily; import org.apache.flink.table.types.logical.LogicalTypeRoot; +import org.apache.flink.table.types.logical.VarCharType; import org.apache.flink.table.types.logical.utils.LogicalTypeChecks; import java.util.List; @@ -54,32 +55,55 @@ class RowToStringCastRule extends AbstractNullAwareCodeGeneratorCastRule<ArrayDa .allMatch(fieldType -> CastRuleProvider.exists(fieldType, target)); } - /* Example generated code for ROW<`f0` INT, `f1` STRING>: + /* Example generated code for ROW<`f0` INT, `f1` STRING> -> CHAR(12): isNull$0 = _myInputIsNull; if (!isNull$0) { builder$1.setLength(0); builder$1.append("("); - int f0Value$2 = -1; - boolean f0IsNull$3 = _myInput.isNullAt(0); - if (!f0IsNull$3) { - f0Value$2 = _myInput.getInt(0); - result$2 = org.apache.flink.table.data.binary.BinaryStringData.fromString("" + f0Value$2); - builder$1.append(result$2); + int f0Value$3 = -1; + boolean f0IsNull$4 = _myInput.isNullAt(0); + if (!f0IsNull$4) { + f0Value$3 = _myInput.getInt(0); + isNull$2 = f0IsNull$4; + if (!isNull$2) { + result$3 = org.apache.flink.table.data.binary.BinaryStringData.fromString("" + f0Value$3); + isNull$2 = result$3 == null; + } else { + result$3 = org.apache.flink.table.data.binary.BinaryStringData.EMPTY_UTF8; + } + builder$1.append(result$3); } else { 
builder$1.append("null"); } - builder$1.append(","); - org.apache.flink.table.data.binary.BinaryStringData f1Value$4 = org.apache.flink.table.data.binary.BinaryStringData.EMPTY_UTF8; - boolean f1IsNull$5 = _myInput.isNullAt(1); - if (!f1IsNull$5) { - f1Value$4 = ((org.apache.flink.table.data.binary.BinaryStringData) _myInput.getString(1)); - builder$1.append(f1Value$4); + builder$1.append(", "); + org.apache.flink.table.data.binary.BinaryStringData f1Value$5 = org.apache.flink.table.data.binary.BinaryStringData.EMPTY_UTF8; + boolean f1IsNull$6 = _myInput.isNullAt(1); + if (!f1IsNull$6) { + f1Value$5 = ((org.apache.flink.table.data.binary.BinaryStringData) _myInput.getString(1)); + builder$1.append(f1Value$5); } else { builder$1.append("null"); } builder$1.append(")"); - result$1 = org.apache.flink.table.data.binary.BinaryStringData.fromString(builder$1.toString()); + java.lang.String resultString$2; + resultString$2 = builder$1.toString(); + if (builder$1.length() > 12) { + resultString$2 = builder$1.substring(0, java.lang.Math.min(builder$1.length(), 12)); + } else { + if (resultString$2.length() < 12) { + int padLength$7; + padLength$7 = 12 - resultString$2.length(); + java.lang.StringBuilder sbPadding$8; + sbPadding$8 = new java.lang.StringBuilder(); + for (int i$9 = 0; i$9 < padLength$7; i$9++) { + sbPadding$8.append(" "); + } + resultString$2 = resultString$2 + sbPadding$8.toString(); + } + } + result$1 = org.apache.flink.table.data.binary.BinaryStringData.fromString(resultString$2); + isNull$0 = result$1 == null; } else { result$1 = org.apache.flink.table.data.binary.BinaryStringData.EMPTY_UTF8; } @@ -98,6 +122,13 @@ class RowToStringCastRule extends AbstractNullAwareCodeGeneratorCastRule<ArrayDa context.declareClassField( className(StringBuilder.class), builderTerm, constructorCall(StringBuilder.class)); + final String resultStringTerm = newName("resultString"); + final int length = LogicalTypeChecks.getLength(targetLogicalType); + final LogicalType 
targetTypeForElementCast = + targetLogicalType.is(LogicalTypeFamily.CHARACTER_STRING) + ? VarCharType.STRING_TYPE + : targetLogicalType; + final CastRuleUtils.CodeWriter writer = new CastRuleUtils.CodeWriter() .stmt(methodCall(builderTerm, "setLength", 0)) @@ -117,7 +148,7 @@ class RowToStringCastRule extends AbstractNullAwareCodeGeneratorCastRule<ArrayDa fieldIsNullTerm, // Null check is done at the row access level fieldType.copy(false), - targetLogicalType); + targetTypeForElementCast); // Write the comma if (fieldIndex != 0) { @@ -154,18 +185,23 @@ class RowToStringCastRule extends AbstractNullAwareCodeGeneratorCastRule<ArrayDa methodCall(builderTerm, "append", NULL_STR_LITERAL))); } - writer.stmt(methodCall(builderTerm, "append", strLiteral(")"))) + writer.stmt(methodCall(builderTerm, "append", strLiteral(")"))); + return CharVarCharTrimPadCastRule.padAndTrimStringIfNeeded( + writer, + targetLogicalType, + context.legacyBehaviour(), + length, + resultStringTerm, + builderTerm) // Assign the result value .assignStmt( returnVariable, CastRuleUtils.staticCall( - BINARY_STRING_DATA_FROM_STRING(), - methodCall(builderTerm, "toString"))); - - return writer.toString(); + BINARY_STRING_DATA_FROM_STRING(), resultStringTerm)) + .toString(); } - private String getDelimiter(CodeGeneratorCastRule.Context context) { + private static String getDelimiter(CodeGeneratorCastRule.Context context) { final String comma; if (context.legacyBehaviour()) { comma = strLiteral(","); diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/TimeToStringCastRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/TimeToStringCastRule.java index cc07a79..e9bdc6b 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/TimeToStringCastRule.java +++ 
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/TimeToStringCastRule.java @@ -24,6 +24,7 @@ import org.apache.flink.table.types.logical.LogicalTypeRoot; import org.apache.flink.table.types.logical.utils.LogicalTypeChecks; import static org.apache.flink.table.planner.codegen.calls.BuiltInMethods.UNIX_TIME_TO_STRING; +import static org.apache.flink.table.types.logical.VarCharType.STRING_TYPE; /** * {@link LogicalTypeRoot#TIME_WITHOUT_TIME_ZONE} to {@link LogicalTypeFamily#CHARACTER_STRING} cast @@ -37,7 +38,7 @@ class TimeToStringCastRule extends AbstractCharacterFamilyTargetRule<Long> { super( CastRulePredicate.builder() .input(LogicalTypeRoot.TIME_WITHOUT_TIME_ZONE) - .target(LogicalTypeFamily.CHARACTER_STRING) + .target(STRING_TYPE) .build()); } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/TimestampToStringCastRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/TimestampToStringCastRule.java index 965b378..0385d44 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/TimestampToStringCastRule.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/TimestampToStringCastRule.java @@ -29,6 +29,7 @@ import org.apache.calcite.avatica.util.DateTimeUtils; import static org.apache.flink.table.planner.codegen.calls.BuiltInMethods.TIMESTAMP_TO_STRING_TIME_ZONE; import static org.apache.flink.table.planner.functions.casting.CastRuleUtils.accessStaticField; import static org.apache.flink.table.planner.functions.casting.CastRuleUtils.staticCall; +import static org.apache.flink.table.types.logical.VarCharType.STRING_TYPE; /** {@link LogicalTypeFamily#TIMESTAMP} to {@link LogicalTypeFamily#CHARACTER_STRING} cast rule. 
*/ class TimestampToStringCastRule extends AbstractCharacterFamilyTargetRule<TimestampData> { @@ -39,7 +40,7 @@ class TimestampToStringCastRule extends AbstractCharacterFamilyTargetRule<Timest super( CastRulePredicate.builder() .input(LogicalTypeFamily.TIMESTAMP) - .target(LogicalTypeFamily.CHARACTER_STRING) + .target(STRING_TYPE) .build()); } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/IfCallGen.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/IfCallGen.scala index 5fe1dd1..9eec6b5 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/IfCallGen.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/IfCallGen.scala @@ -21,7 +21,7 @@ package org.apache.flink.table.planner.codegen.calls import org.apache.flink.table.planner.codegen.CodeGenUtils.{className, primitiveDefaultValue, primitiveTypeTermForType} import org.apache.flink.table.planner.codegen.calls.ScalarOperatorGens.toCodegenCastContext import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, GeneratedExpression} -import org.apache.flink.table.planner.functions.casting.{CastRuleProvider, ExpressionCodeGeneratorCastRule} +import org.apache.flink.table.planner.functions.casting.{CastCodeBlock, CastRuleProvider, CodeGeneratorCastRule, ExpressionCodeGeneratorCastRule} import org.apache.flink.table.types.logical.LogicalType /** @@ -54,18 +54,20 @@ class IfCallGen() extends CallGenerator { val resultCode = s""" |// --- Start code generated by ${className[IfCallGen]} + |${castedResultTerm1.getCode} + |${castedResultTerm2.getCode} |${operands.head.code} |$resultTerm = $resultDefault; |if (${operands.head.resultTerm}) { | ${operands(1).code} | if (!${operands(1).nullTerm}) { - | $resultTerm = $castedResultTerm1; + | $resultTerm = ${castedResultTerm1.getReturnTerm}; | } | $nullTerm = 
${operands(1).nullTerm}; |} else { | ${operands(2).code} | if (!${operands(2).nullTerm}) { - | $resultTerm = $castedResultTerm2; + | $resultTerm = ${castedResultTerm2.getReturnTerm}; | } | $nullTerm = ${operands(2).nullTerm}; |} @@ -80,13 +82,24 @@ class IfCallGen() extends CallGenerator { * or null if no casting can be performed */ private def normalizeArgument( - ctx: CodeGeneratorContext, expr: GeneratedExpression, targetType: LogicalType): String = { + ctx: CodeGeneratorContext, + expr: GeneratedExpression, + targetType: LogicalType): CastCodeBlock = { + val rule = CastRuleProvider.resolve(expr.resultType, targetType) rule match { case codeGeneratorCastRule: ExpressionCodeGeneratorCastRule[_, _] => - codeGeneratorCastRule.generateExpression( + CastCodeBlock.withoutCode(codeGeneratorCastRule.generateExpression( + toCodegenCastContext(ctx), + expr.resultTerm, + expr.resultType, + targetType + ), expr.nullTerm) + case codeGeneratorCastRule: CodeGeneratorCastRule[_, _] => + codeGeneratorCastRule.generateCodeBlock( toCodegenCastContext(ctx), expr.resultTerm, + expr.nullTerm, expr.resultType, targetType ) diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/CastFunctionITCase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/CastFunctionITCase.java index f21a26d..ae5481b 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/CastFunctionITCase.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/CastFunctionITCase.java @@ -19,6 +19,7 @@ package org.apache.flink.table.planner.functions; import org.apache.flink.configuration.Configuration; +import org.apache.flink.table.api.config.ExecutionConfigOptions; import org.apache.flink.table.api.config.TableConfigOptions; import org.apache.flink.table.functions.BuiltInFunctionDefinitions; import org.apache.flink.table.types.AbstractDataType; @@ 
-70,6 +71,7 @@ import static org.apache.flink.table.api.DataTypes.VARBINARY; import static org.apache.flink.table.api.DataTypes.VARCHAR; import static org.apache.flink.table.api.DataTypes.YEAR; import static org.apache.flink.table.api.Expressions.$; +import static org.apache.flink.table.api.config.ExecutionConfigOptions.LegacyCastBehaviour; /** Tests for {@link BuiltInFunctionDefinitions#CAST}. */ public class CastFunctionITCase extends BuiltInFunctionTestBase { @@ -107,7 +109,11 @@ public class CastFunctionITCase extends BuiltInFunctionTestBase { @Override protected Configuration configuration() { - return super.configuration().set(TableConfigOptions.LOCAL_TIME_ZONE, TEST_TZ.getId()); + return super.configuration() + .set(TableConfigOptions.LOCAL_TIME_ZONE, TEST_TZ.getId()) + .set( + ExecutionConfigOptions.TABLE_EXEC_LEGACY_CAST_BEHAVIOUR, + LegacyCastBehaviour.DISABLED); } @Parameterized.Parameters(name = "{index}: {0}") @@ -125,27 +131,22 @@ public class CastFunctionITCase extends BuiltInFunctionTestBase { CastTestSpecBuilder.testCastTo(CHAR(3)) .fromCase(CHAR(5), null, null) .fromCase(CHAR(3), "foo", "foo") - .fromCase(CHAR(4), "foo", "foo ") - .fromCase(CHAR(4), "foo ", "foo ") .fromCase(VARCHAR(3), "foo", "foo") .fromCase(VARCHAR(5), "foo", "foo") - .fromCase(VARCHAR(5), "foo ", "foo ") - // https://issues.apache.org/jira/browse/FLINK-24413 - Trim to precision - // in this case down to 3 chars - .fromCase(STRING(), "abcdef", "abcdef") // "abc" - .fromCase(DATE(), DEFAULT_DATE, "2021-09-24") // "202" + .fromCase(STRING(), "abcdef", "abc") + .fromCase(DATE(), DEFAULT_DATE, "202") + .build(), + CastTestSpecBuilder.testCastTo(CHAR(5)) + .fromCase(CHAR(5), null, null) + .fromCase(CHAR(3), "foo", "foo ") .build(), CastTestSpecBuilder.testCastTo(VARCHAR(3)) .fromCase(VARCHAR(5), null, null) .fromCase(CHAR(3), "foo", "foo") - .fromCase(CHAR(4), "foo", "foo ") - .fromCase(CHAR(4), "foo ", "foo ") + .fromCase(CHAR(4), "foo", "foo") .fromCase(VARCHAR(3), "foo", "foo") 
.fromCase(VARCHAR(5), "foo", "foo") - .fromCase(VARCHAR(5), "foo ", "foo ") - // https://issues.apache.org/jira/browse/FLINK-24413 - Trim to precision - // in this case down to 3 chars - .fromCase(STRING(), "abcdef", "abcdef") + .fromCase(STRING(), "abcdef", "abc") .build(), CastTestSpecBuilder.testCastTo(STRING()) .fromCase(STRING(), null, null) diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/casting/CastRuleProviderTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/casting/CastRuleProviderTest.java index 5ee289a..97778b9 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/casting/CastRuleProviderTest.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/casting/CastRuleProviderTest.java @@ -20,13 +20,16 @@ package org.apache.flink.table.planner.functions.casting; import org.apache.flink.table.catalog.ObjectIdentifier; import org.apache.flink.table.types.logical.ArrayType; +import org.apache.flink.table.types.logical.CharType; import org.apache.flink.table.types.logical.DistinctType; import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.VarCharType; import org.junit.jupiter.api.Test; import static org.apache.flink.table.api.DataTypes.BIGINT; import static org.apache.flink.table.api.DataTypes.INT; +import static org.apache.flink.table.types.logical.VarCharType.STRING_TYPE; import static org.assertj.core.api.Assertions.assertThat; class CastRuleProviderTest { @@ -58,4 +61,20 @@ class CastRuleProviderTest { assertThat(CastRuleProvider.resolve(new ArrayType(INT), new ArrayType(DISTINCT_BIG_INT))) .isSameAs(ArrayToArrayCastRule.INSTANCE); } + + @Test + void testResolvePredefinedToString() { + assertThat(CastRuleProvider.resolve(INT, new VarCharType(10))) + .isSameAs(CharVarCharTrimPadCastRule.INSTANCE); + 
assertThat(CastRuleProvider.resolve(INT, new CharType(10))) + .isSameAs(CharVarCharTrimPadCastRule.INSTANCE); + assertThat(CastRuleProvider.resolve(INT, STRING_TYPE)) + .isSameAs(NumericToStringCastRule.INSTANCE); + } + + @Test + void testResolveConstructedToString() { + assertThat(CastRuleProvider.resolve(new ArrayType(INT), new VarCharType(10))) + .isSameAs(ArrayToStringCastRule.INSTANCE); + } } diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/casting/CastRulesTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/casting/CastRulesTest.java index 99cbfb7..76cd6d9 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/casting/CastRulesTest.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/casting/CastRulesTest.java @@ -18,6 +18,7 @@ package org.apache.flink.table.planner.functions.casting; +import org.apache.flink.api.common.typeutils.base.LocalDateSerializer; import org.apache.flink.api.common.typeutils.base.LocalDateTimeSerializer; import org.apache.flink.table.api.TableException; import org.apache.flink.table.data.GenericArrayData; @@ -88,6 +89,7 @@ import static org.apache.flink.table.api.DataTypes.VARCHAR; import static org.apache.flink.table.api.DataTypes.YEAR; import static org.apache.flink.table.data.DecimalData.fromBigDecimal; import static org.apache.flink.table.data.StringData.fromString; +import static org.apache.flink.table.data.binary.BinaryStringData.EMPTY_UTF8; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -101,6 +103,8 @@ class CastRulesTest { private static final CastRule.Context CET_CONTEXT = CastRule.Context.create(false, CET, Thread.currentThread().getContextClassLoader()); + private static final CastRule.Context CET_CONTEXT_LEGACY = + CastRule.Context.create(true, CET, 
Thread.currentThread().getContextClassLoader()); private static final byte DEFAULT_POSITIVE_TINY_INT = (byte) 5; private static final byte DEFAULT_NEGATIVE_TINY_INT = (byte) -5; @@ -606,6 +610,334 @@ class CastRulesTest { fromString("c") })), fromString("(10, null, 12:34:56.123, [a, b, c])")), + CastTestSpecBuilder.testCastTo(CHAR(6)) + .fromCase(STRING(), null, EMPTY_UTF8, false) + .fromCase(STRING(), null, EMPTY_UTF8, true) + .fromCase(CHAR(6), fromString("Apache"), fromString("Apache"), false) + .fromCase(CHAR(6), fromString("Apache"), fromString("Apache"), true) + .fromCase(VARCHAR(5), fromString("Flink"), fromString("Flink "), false) + .fromCase(VARCHAR(5), fromString("Flink"), fromString("Flink"), true) + .fromCase(STRING(), fromString("foo"), fromString("foo "), false) + .fromCase(STRING(), fromString("foo"), fromString("foo"), true) + .fromCase(BOOLEAN(), true, fromString("true "), false) + .fromCase(BOOLEAN(), true, fromString("true"), true) + .fromCase(BOOLEAN(), false, fromString("false "), false) + .fromCase(BOOLEAN(), false, fromString("false"), true) + .fromCase( + BINARY(3), + new byte[] {0, 1, 2}, + fromString("\u0000\u0001\u0002 "), + false) + .fromCase( + BINARY(3), + new byte[] {0, 1, 2}, + fromString("\u0000\u0001\u0002"), + true) + .fromCase( + VARBINARY(4), + new byte[] {0, 1, 2, 3}, + fromString("\u0000\u0001\u0002\u0003 "), + false) + .fromCase( + VARBINARY(4), + new byte[] {0, 1, 2, 3}, + fromString("\u0000\u0001\u0002\u0003"), + true) + .fromCase( + BYTES(), + new byte[] {0, 1, 2, 3, 4}, + fromString("\u0000\u0001\u0002\u0003\u0004 "), + false) + .fromCase( + BYTES(), + new byte[] {0, 1, 2, 3, 4}, + fromString("\u0000\u0001\u0002\u0003\u0004"), + true) + .fromCase(TINYINT(), (byte) -125, fromString("-125 "), false) + .fromCase(TINYINT(), (byte) -125, fromString("-125"), true) + .fromCase(SMALLINT(), (short) 32767, fromString("32767 "), false) + .fromCase(SMALLINT(), (short) 32767, fromString("32767"), true) + .fromCase(INT(), -1234, 
fromString("-1234 "), false) + .fromCase(INT(), -1234, fromString("-1234"), true) + .fromCase(BIGINT(), 12345L, fromString("12345 "), false) + .fromCase(BIGINT(), 12345L, fromString("12345"), true) + .fromCase(FLOAT(), -1.23f, fromString("-1.23 "), false) + .fromCase(FLOAT(), -1.23f, fromString("-1.23"), true) + .fromCase(DOUBLE(), 123.4d, fromString("123.4 "), false) + .fromCase(DOUBLE(), 123.4d, fromString("123.4"), true) + .fromCase(INTERVAL(YEAR()), 84, fromString("+7-00 "), false) + .fromCase(INTERVAL(YEAR()), 84, fromString("+7-00"), true) + .fromCase(INTERVAL(MONTH()), 5, fromString("+0-05 "), false) + .fromCase(INTERVAL(MONTH()), 5, fromString("+0-05"), true), + CastTestSpecBuilder.testCastTo(CHAR(12)) + .fromCase( + ARRAY(INT()), + new GenericArrayData(new int[] {-1, 2, 3}), + fromString("[-1, 2, 3] "), + false) + .fromCase( + ARRAY(INT()), + new GenericArrayData(new int[] {-1, 2, 3}), + fromString("[-1, 2, 3]"), + true) + .fromCase(ARRAY(INT()).nullable(), null, EMPTY_UTF8, false) + .fromCase(ARRAY(INT()).nullable(), null, EMPTY_UTF8, true) + .fromCase( + MAP(STRING(), INT()), + mapData(entry(fromString("a"), 1), entry(fromString("b"), 8)), + fromString("{a=1, b=8} "), + false) + .fromCase( + MAP(STRING(), INT()), + mapData(entry(fromString("a"), 1), entry(fromString("b"), 8)), + fromString("{a=1, b=8}"), + true) + .fromCase( + MAP(STRING(), INTERVAL(MONTH())).nullable(), null, EMPTY_UTF8, true) + .fromCase( + MAP(STRING(), INTERVAL(MONTH())).nullable(), + null, + EMPTY_UTF8, + false) + .fromCase( + MULTISET(STRING()), + mapData(entry(fromString("a"), 1), entry(fromString("b"), 1)), + fromString("{a=1, b=1} "), + false) + .fromCase( + MULTISET(STRING()), + mapData(entry(fromString("a"), 1), entry(fromString("b"), 1)), + fromString("{a=1, b=1}"), + true) + .fromCase(MULTISET(STRING()).nullable(), null, EMPTY_UTF8, false) + .fromCase(MULTISET(STRING()), null, EMPTY_UTF8, true) + .fromCase( + ROW(FIELD("f0", INT()), FIELD("f1", STRING())), + 
GenericRowData.of(123, fromString("foo")), + fromString("(123, foo) "), + false) + .fromCase( + ROW(FIELD("f0", INT()), FIELD("f1", STRING())), + GenericRowData.of(123, fromString("foo")), + fromString("(123,foo)"), + true) + .fromCase( + ROW(FIELD("f0", STRING()), FIELD("f1", STRING())).nullable(), + null, + EMPTY_UTF8, + false) + .fromCase( + ROW(FIELD("f0", STRING()), FIELD("f1", STRING())).nullable(), + null, + EMPTY_UTF8, + true) + .fromCase( + RAW(LocalDate.class, new LocalDateSerializer()), + RawValueData.fromObject(LocalDate.parse("2020-12-09")), + fromString("2020-12-09 "), + false) + .fromCase( + RAW(LocalDate.class, new LocalDateSerializer()), + RawValueData.fromObject(LocalDate.parse("2020-12-09")), + fromString("2020-12-09"), + true) + .fromCase( + RAW(LocalDateTime.class, new LocalDateTimeSerializer()).nullable(), + null, + EMPTY_UTF8, + false) + .fromCase( + RAW(LocalDateTime.class, new LocalDateTimeSerializer()).nullable(), + null, + EMPTY_UTF8, + true), + CastTestSpecBuilder.testCastTo(VARCHAR(3)) + .fromCase(STRING(), null, EMPTY_UTF8, false) + .fromCase(STRING(), null, EMPTY_UTF8, true) + .fromCase(CHAR(6), fromString("Apache"), fromString("Apa"), false) + .fromCase(CHAR(6), fromString("Apache"), fromString("Apache"), true) + .fromCase(VARCHAR(5), fromString("Flink"), fromString("Fli"), false) + .fromCase(VARCHAR(5), fromString("Flink"), fromString("Flink"), true) + .fromCase(STRING(), fromString("Apache Flink"), fromString("Apa"), false) + .fromCase( + STRING(), + fromString("Apache Flink"), + fromString("Apache Flink"), + true) + .fromCase(BOOLEAN(), true, fromString("tru"), false) + .fromCase(BOOLEAN(), true, fromString("true"), true) + .fromCase(BOOLEAN(), false, fromString("fal"), false) + .fromCase(BOOLEAN(), false, fromString("false"), true) + .fromCase( + BINARY(5), + new byte[] {0, 1, 2, 3, 4}, + fromString("\u0000\u0001\u0002"), + false) + .fromCase( + BINARY(5), + new byte[] {0, 1, 2, 3, 4}, + 
fromString("\u0000\u0001\u0002\u0003\u0004"), + true) + .fromCase( + VARBINARY(5), + new byte[] {0, 1, 2, 3, 4}, + fromString("\u0000\u0001\u0002"), + false) + .fromCase( + VARBINARY(5), + new byte[] {0, 1, 2, 3, 4}, + fromString("\u0000\u0001\u0002\u0003\u0004"), + true) + .fromCase( + BYTES(), + new byte[] {0, 1, 2, 3, 4}, + fromString("\u0000\u0001\u0002"), + false) + .fromCase( + BYTES(), + new byte[] {0, 1, 2, 3, 4}, + fromString("\u0000\u0001\u0002\u0003\u0004"), + true) + .fromCase( + DECIMAL(4, 3), + fromBigDecimal(new BigDecimal("9.8765"), 5, 4), + fromString("9.8"), + false) + .fromCase( + DECIMAL(4, 3), + fromBigDecimal(new BigDecimal("9.8765"), 5, 4), + fromString("9.8765"), + true) + .fromCase(TINYINT(), (byte) -125, fromString("-12"), false) + .fromCase(TINYINT(), (byte) -125, fromString("-125"), true) + .fromCase(SMALLINT(), (short) 32767, fromString("327"), false) + .fromCase(SMALLINT(), (short) 32767, fromString("32767"), true) + .fromCase(INT(), -12345678, fromString("-12"), false) + .fromCase(INT(), -12345678, fromString("-12345678"), true) + .fromCase(BIGINT(), 1234567891234L, fromString("123"), false) + .fromCase(BIGINT(), 1234567891234L, fromString("1234567891234"), true) + .fromCase(FLOAT(), -123.456f, fromString("-12"), false) + .fromCase(FLOAT(), -123.456f, fromString("-123.456"), true) + .fromCase(DOUBLE(), 12345.678901d, fromString("123"), false) + .fromCase(DOUBLE(), 12345.678901d, fromString("12345.678901"), true) + .fromCase(FLOAT(), Float.MAX_VALUE, fromString("3.4"), false) + .fromCase( + FLOAT(), + Float.MAX_VALUE, + fromString(String.valueOf(Float.MAX_VALUE)), + true) + .fromCase(DOUBLE(), Double.MAX_VALUE, fromString("1.7"), false) + .fromCase( + DOUBLE(), + Double.MAX_VALUE, + fromString(String.valueOf(Double.MAX_VALUE)), + true) + .fromCase(TIMESTAMP(), TIMESTAMP, fromString("202"), false) + .fromCase(TIMESTAMP(), TIMESTAMP, TIMESTAMP_STRING, true) + .fromCase(TIMESTAMP_LTZ(), CET_CONTEXT, TIMESTAMP, fromString("202")) + 
.fromCase( + TIMESTAMP_LTZ(), + CET_CONTEXT_LEGACY, + TIMESTAMP, + TIMESTAMP_STRING_CET) + .fromCase(DATE(), DATE, fromString("202"), false) + .fromCase(DATE(), DATE, DATE_STRING, true) + .fromCase(TIME(5), TIME, fromString("12:"), false) + .fromCase(TIME(5), TIME, TIME_STRING, true) + .fromCase(INTERVAL(YEAR()), 84, fromString("+7-"), false) + .fromCase(INTERVAL(YEAR()), 84, fromString("+7-00"), true) + .fromCase(INTERVAL(MONTH()), 5, fromString("+0-"), false) + .fromCase(INTERVAL(MONTH()), 5, fromString("+0-05"), true) + .fromCase(INTERVAL(DAY()), 10L, fromString("+0 "), false) + .fromCase(INTERVAL(DAY()), 10L, fromString("+0 00:00:00.010"), true) + .fromCase( + ARRAY(INT()), + new GenericArrayData(new int[] {-123, 456}), + fromString("[-1"), + false) + .fromCase( + ARRAY(INT()), + new GenericArrayData(new int[] {-123, 456}), + fromString("[-123, 456]"), + true) + .fromCase(ARRAY(INT()).nullable(), null, EMPTY_UTF8, false) + .fromCase(ARRAY(INT()).nullable(), null, EMPTY_UTF8, true) + .fromCase( + MAP(STRING(), INTERVAL(MONTH())), + mapData(entry(fromString("a"), -123), entry(fromString("b"), 123)), + fromString("{a="), + false) + .fromCase( + MAP(STRING(), INTERVAL(MONTH())), + mapData(entry(fromString("a"), -123), entry(fromString("b"), 123)), + fromString("{a=-10-03, b=+10-03}"), + true) + .fromCase( + MAP(STRING(), INTERVAL(MONTH())).nullable(), + null, + EMPTY_UTF8, + false) + .fromCase( + MAP(STRING(), INTERVAL(MONTH())).nullable(), null, EMPTY_UTF8, true) + .fromCase( + MAP(STRING(), INTERVAL(MONTH())).nullable(), + null, + EMPTY_UTF8, + false) + .fromCase( + MULTISET(STRING()), + mapData(entry(fromString("a"), 1), entry(fromString("b"), 1)), + fromString("{a="), + false) + .fromCase( + MULTISET(STRING()), + mapData(entry(fromString("a"), 1), entry(fromString("b"), 1)), + fromString("{a=1, b=1}"), + true) + .fromCase(MULTISET(STRING()).nullable(), null, EMPTY_UTF8, false) + .fromCase(MULTISET(STRING()), null, EMPTY_UTF8, true) + .fromCase( + 
ROW(FIELD("f0", INT()), FIELD("f1", STRING())), + GenericRowData.of(123, fromString("abc")), + fromString("(12"), + false) + .fromCase( + ROW(FIELD("f0", INT()), FIELD("f1", STRING())), + GenericRowData.of(123, fromString("abc")), + fromString("(123,abc)"), + true) + .fromCase( + ROW(FIELD("f0", STRING()), FIELD("f1", STRING())).nullable(), + null, + EMPTY_UTF8, + false) + .fromCase( + ROW(FIELD("f0", STRING()), FIELD("f1", STRING())).nullable(), + null, + EMPTY_UTF8, + true) + .fromCase( + RAW(LocalDateTime.class, new LocalDateTimeSerializer()), + RawValueData.fromObject( + LocalDateTime.parse("2020-11-11T18:08:01.123")), + fromString("202"), + false) + .fromCase( + RAW(LocalDateTime.class, new LocalDateTimeSerializer()), + RawValueData.fromObject( + LocalDateTime.parse("2020-11-11T18:08:01.123")), + fromString("2020-11-11T18:08:01.123"), + true) + .fromCase( + RAW(LocalDateTime.class, new LocalDateTimeSerializer()).nullable(), + null, + EMPTY_UTF8, + false) + .fromCase( + RAW(LocalDateTime.class, new LocalDateTimeSerializer()).nullable(), + null, + EMPTY_UTF8, + true), CastTestSpecBuilder.testCastTo(BOOLEAN()) .fromCase(BOOLEAN(), null, null) .fail(CHAR(3), fromString("foo"), TableException.class) diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/ScalarFunctionsTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/ScalarFunctionsTest.scala index 396ebca..410b5ed 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/ScalarFunctionsTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/ScalarFunctionsTest.scala @@ -4303,10 +4303,10 @@ class ScalarFunctionsTest extends ScalarTypesTestBase { val url = "CAST('http://user:pass@host' AS VARCHAR(50))" val base64 = "CAST('aGVsbG8gd29ybGQ=' AS VARCHAR(20))" - testSqlApi(s"IFNULL(SUBSTR($str1, 2, 3), $str2)", "ell") - 
testSqlApi(s"IFNULL(SUBSTRING($str1, 2, 3), $str2)", "ell") - testSqlApi(s"IFNULL(LEFT($str1, 3), $str2)", "Hel") - testSqlApi(s"IFNULL(RIGHT($str1, 3), $str2)", "llo") + testSqlApi(s"IFNULL(SUBSTR($str1, 2, 3), $str2)", "el") + testSqlApi(s"IFNULL(SUBSTRING($str1, 2, 3), $str2)", "el") + testSqlApi(s"IFNULL(LEFT($str1, 3), $str2)", "He") + testSqlApi(s"IFNULL(RIGHT($str1, 3), $str2)", "ll") testSqlApi(s"IFNULL(REGEXP_EXTRACT($str1, 'H(.+?)l(.+?)$$', 2), $str2)", "lo") testSqlApi(s"IFNULL(REGEXP_REPLACE($str1, 'e.l', 'EXL'), $str2)", "HEXLo") testSqlApi(s"IFNULL(UPPER($str1), $str2)", "HELLO") @@ -4316,9 +4316,9 @@ class ScalarFunctionsTest extends ScalarTypesTestBase { testSqlApi(s"IFNULL(LPAD($str1, 7, $str3), $str2)", "heHello") testSqlApi(s"IFNULL(RPAD($str1, 7, $str3), $str2)", "Hellohe") testSqlApi(s"IFNULL(REPEAT($str1, 2), $str2)", "HelloHello") - testSqlApi(s"IFNULL(REVERSE($str1), $str2)", "olleH") + testSqlApi(s"IFNULL(REVERSE($str1), $str2)", "ol") testSqlApi(s"IFNULL(REPLACE($str3, ' ', '_'), $str2)", "hello_world") - testSqlApi(s"IFNULL(SPLIT_INDEX($str3, ' ', 1), $str2)", "world") + testSqlApi(s"IFNULL(SPLIT_INDEX($str3, ' ', 1), $str2)", "wo") testSqlApi(s"IFNULL(MD5($str1), $str2)", "8b1a9953c4611296a827abf8c47804d7") testSqlApi(s"IFNULL(SHA1($str1), $str2)", "f7ff9e8b7bb2e09b70935a5d785e0cc5d9d0abf0") testSqlApi( @@ -4338,7 +4338,7 @@ class ScalarFunctionsTest extends ScalarTypesTestBase { testSqlApi( s"IFNULL(SHA2($str1, 256), $str2)", "185f8db32271fe25f561a6fc938b2e264306ec304eda518007d1764826381969") - testSqlApi(s"IFNULL(PARSE_URL($url, 'HOST'), $str2)", "host") + testSqlApi(s"IFNULL(PARSE_URL($url, 'HOST'), $str2)", "ho") testSqlApi(s"IFNULL(FROM_BASE64($base64), $str2)", "hello world") testSqlApi(s"IFNULL(TO_BASE64($str3), $str2)", "aGVsbG8gd29ybGQ=") testSqlApi(s"IFNULL(CHR(65), $str2)", "A") @@ -4350,7 +4350,7 @@ class ScalarFunctionsTest extends ScalarTypesTestBase { testSqlApi(s"IFNULL(RTRIM($str4), $str2)", " hello") 
testSqlApi(s"IFNULL($str1 || $str2, $str2)", "HelloHi") testSqlApi(s"IFNULL(SUBSTRING(UUID(), 9, 1), $str2)", "-") - testSqlApi(s"IFNULL(DECODE(ENCODE($str1, 'utf-8'), 'utf-8'), $str2)", "Hello") + testSqlApi(s"IFNULL(DECODE(ENCODE($str1, 'utf-8'), 'utf-8'), $str2)", "He") testSqlApi(s"IFNULL(CAST(DATE '2021-04-06' AS VARCHAR(10)), $str2)", "2021-04-06") testSqlApi(s"IFNULL(CAST(TIME '11:05:30' AS VARCHAR(8)), $str2)", "11:05:30")