This is an automated email from the ASF dual-hosted git repository.
fanjia pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/seatunnel.git
The following commit(s) were added to refs/heads/dev by this push:
new eae369b804 [Improve][transform-v2] Support dynamic types for array
function (#8139)
eae369b804 is described below
commit eae369b804d79f0c32b62e7fe42806069348adc5
Author: CosmosNi <[email protected]>
AuthorDate: Thu Dec 26 10:05:09 2024 +0800
[Improve][transform-v2] Support dynamic types for array function (#8139)
---
docs/en/transform-v2/sql-functions.md | 11 +-
docs/zh/transform-v2/sql-functions.md | 10 +-
.../resources/sql_transform/explode_transform.conf | 2 +-
.../explode_transform_without_outer.conf | 2 +-
.../test/resources/sql_transform/func_array.conf | 73 ++++++-
.../transform/sql/zeta/ZetaSQLEngine.java | 10 +-
.../transform/sql/zeta/ZetaSQLFunction.java | 29 +--
.../seatunnel/transform/sql/zeta/ZetaSQLType.java | 2 +
.../sql/zeta/functions/ArrayFunction.java | 222 +++++++++++++++++++++
9 files changed, 327 insertions(+), 34 deletions(-)
diff --git a/docs/en/transform-v2/sql-functions.md
b/docs/en/transform-v2/sql-functions.md
index 5c5f869dd4..a613b41356 100644
--- a/docs/en/transform-v2/sql-functions.md
+++ b/docs/en/transform-v2/sql-functions.md
@@ -998,8 +998,15 @@ Generate an array.
Example:
-select ARRAY('test1','test2','test3') as arrays
-
+SELECT Array('c_1','c_2') as string_array,
+ Array(1.23,2.34) as double_array,
+ Array(1,2) as int_array,
+ Array(2147483648,2147483649) as long_array,
+ Array(1.23,2147483648) as double_array_1,
+ Array(1.23,2147483648,'c_1') as string_array_1
+FROM fake
+
+Note: Currently only the string, double, long, and int types are supported.
### LATERAL VIEW
#### EXPLODE
diff --git a/docs/zh/transform-v2/sql-functions.md
b/docs/zh/transform-v2/sql-functions.md
index 7e3f8454e1..a5c616926f 100644
--- a/docs/zh/transform-v2/sql-functions.md
+++ b/docs/zh/transform-v2/sql-functions.md
@@ -991,7 +991,15 @@ select UUID() as seatunnel_uuid
示例:
-select ARRAY('test1','test2','test3') as arrays
+SELECT Array('c_1','c_2') as string_array,
+ Array(1.23,2.34) as double_array,
+ Array(1,2) as int_array,
+ Array(2147483648,2147483649) as long_array,
+ Array(1.23,2147483648) as double_array_1,
+ Array(1.23,2147483648,'c_1') as string_array_1
+FROM fake
+
+注意:目前仅支持string、double、long、int几种类型
### LATERAL VIEW
#### EXPLODE
diff --git
a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-2/src/test/resources/sql_transform/explode_transform.conf
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-2/src/test/resources/sql_transform/explode_transform.conf
index 1bfb6d18ff..c6ef17abd0 100644
---
a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-2/src/test/resources/sql_transform/explode_transform.conf
+++
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-2/src/test/resources/sql_transform/explode_transform.conf
@@ -94,7 +94,7 @@ sink{
},
{
field_name = num
- field_type = string
+ field_type = int
field_value = [{equals_to = 1}]
}
]
diff --git
a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-2/src/test/resources/sql_transform/explode_transform_without_outer.conf
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-2/src/test/resources/sql_transform/explode_transform_without_outer.conf
index b5c96050c6..0601281ca1 100644
---
a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-2/src/test/resources/sql_transform/explode_transform_without_outer.conf
+++
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-2/src/test/resources/sql_transform/explode_transform_without_outer.conf
@@ -86,7 +86,7 @@ sink{
},
{
field_name = num
- field_type = "string"
+ field_type = "int"
field_value = [{equals_to = 1}]
}
]
diff --git
a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-2/src/test/resources/sql_transform/func_array.conf
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-2/src/test/resources/sql_transform/func_array.conf
index b743419cfe..6db9535e0c 100644
---
a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-2/src/test/resources/sql_transform/func_array.conf
+++
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-2/src/test/resources/sql_transform/func_array.conf
@@ -18,6 +18,7 @@
###### This config file is a demonstration of streaming processing in
seatunnel config
######
+
env {
job.mode = "BATCH"
parallelism = 1
@@ -30,6 +31,7 @@ source {
fields {
pk_id = string
name = string
+ id = int
}
primaryKey {
name = "pk_id"
@@ -39,7 +41,7 @@ source {
rows = [
{
kind = INSERT
- fields = ["id001", "zhangsan,zhangsan"]
+ fields = ["id001", "zhangsan,zhangsan",123]
}
]
}
@@ -47,14 +49,25 @@ source {
transform {
Sql {
- plugin_input = "fake"
- plugin_output = "fake1"
- query = "SELECT *,Array('c_1','c_2') as c_array FROM dual "
+ plugin_output = "fake"
+ query = """SELECT
+ *,
+ Array(pk_id,id) as field_array_1,
+ Array(pk_id,'c_1') as field_array_2,
+ Array(id,123) as field_array_3,
+ Array('c_1','c_2') as string_array,
+ Array(1.23,2.34) as double_array,
+ Array(1,2) as int_array,
+ Array(2147483648,2147483649) as long_array,
+ Array(1.23,2147483648) as double_array_1,
+ Array(1.23,2147483648,'c_1') as string_array_1
+ FROM fake """
}
}
sink{
assert {
+ plugin_output = "fake"
rules =
{
row_rules = [
@@ -79,12 +92,56 @@ sink{
field_value = [{equals_to = "zhangsan,zhangsan"}]
},
{
- field_name = c_array
- field_type = array<string>
+ field_name = id
+ field_type = int
+ field_value = [{equals_to = 123}]
+ },
+ {
+ field_name = field_array_1
+ field_type = array<STRING>
+ field_value = [{equals_to = ["id001" ,"123"]}]
+ },
+ {
+ field_name = field_array_2
+ field_type = array<STRING>
+ field_value = [{equals_to = ["id001" ,"c_1"]}]
+ },
+ {
+ field_name = field_array_3
+ field_type = array<INT>
+ field_value = [{equals_to = [123 ,123]}]
+ },
+ {
+ field_name = string_array
+ field_type = array<STRING>
field_value = [{equals_to = ["c_1" ,"c_2"]}]
- }
+ },
+ {
+ field_name = double_array
+ field_type = array<DOUBLE>
+ field_value = [{equals_to = [1.23,2.34]}]
+ },
+ {
+ field_name = int_array
+ field_type = array<INT>
+ field_value = [{equals_to = [1,2]}]
+ },
+ {
+ field_name = long_array
+ field_type = array<BIGINT>
+ field_value = [{equals_to = [2147483648,2147483649]}]
+ },
+ {
+ field_name = double_array_1
+ field_type = array<DOUBLE>
+ field_value = [{equals_to = [1.23,2147483648]}]
+ },
+ {
+ field_name = string_array_1
+ field_type = array<STRING>
+ field_value = [{equals_to = ["1.23","2147483648","c_1"]}]
+ }
]
}
}
}
-
diff --git
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLEngine.java
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLEngine.java
index d26e47de3c..9ce552c289 100644
---
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLEngine.java
+++
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLEngine.java
@@ -59,6 +59,7 @@ public class ZetaSQLEngine implements SQLEngine {
private String inputTableName;
@Nullable private String catalogTableName;
private SeaTunnelRowType inputRowType;
+ private SeaTunnelRowType outRowType;
private String sql;
private PlainSelect selectBody;
@@ -216,10 +217,13 @@ public class ZetaSQLEngine implements SQLEngine {
}
List<LateralView> lateralViews = selectBody.getLateralViews();
if (CollectionUtils.isEmpty(lateralViews)) {
- return new SeaTunnelRowType(fieldNames, seaTunnelDataTypes);
+ outRowType = new SeaTunnelRowType(fieldNames, seaTunnelDataTypes);
+ } else {
+ outRowType =
+ zetaSQLFunction.lateralViewMapping(
+ fieldNames, seaTunnelDataTypes, lateralViews,
inputColumnsMapping);
}
- return zetaSQLFunction.lateralViewMapping(
- fieldNames, seaTunnelDataTypes, lateralViews,
inputColumnsMapping);
+ return outRowType;
}
private static String cleanEscape(String columnName) {
diff --git
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLFunction.java
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLFunction.java
index 8cbc3ed86a..8a69142c1d 100644
---
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLFunction.java
+++
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLFunction.java
@@ -19,7 +19,6 @@ package org.apache.seatunnel.transform.sql.zeta;
import org.apache.seatunnel.api.table.catalog.PhysicalColumn;
import org.apache.seatunnel.api.table.type.ArrayType;
-import org.apache.seatunnel.api.table.type.BasicType;
import org.apache.seatunnel.api.table.type.DecimalType;
import org.apache.seatunnel.api.table.type.MapType;
import org.apache.seatunnel.api.table.type.SeaTunnelDataType;
@@ -29,6 +28,7 @@ import org.apache.seatunnel.api.table.type.SqlType;
import org.apache.seatunnel.common.exception.CommonErrorCodeDeprecated;
import org.apache.seatunnel.common.exception.SeaTunnelRuntimeException;
import org.apache.seatunnel.transform.exception.TransformException;
+import org.apache.seatunnel.transform.sql.zeta.functions.ArrayFunction;
import org.apache.seatunnel.transform.sql.zeta.functions.DateTimeFunction;
import org.apache.seatunnel.transform.sql.zeta.functions.NumericFunction;
import org.apache.seatunnel.transform.sql.zeta.functions.StringFunction;
@@ -192,6 +192,7 @@ public class ZetaSQLFunction {
public static final String UUID = "UUID";
private final SeaTunnelRowType inputRowType;
+
private final ZetaSQLType zetaSQLType;
private final ZetaSQLFilter zetaSQLFilter;
@@ -552,7 +553,7 @@ public class ZetaSQLFunction {
case NULLIF:
return SystemFunction.nullif(args);
case ARRAY:
- return SystemFunction.array(args);
+ return ArrayFunction.array(args);
case UUID:
return randomUUID().toString();
default:
@@ -743,8 +744,7 @@ public class ZetaSQLFunction {
next,
aliasFieldIndex,
row,
- expression,
- true);
+ expression);
}
seaTunnelRows = next;
} else if (expression instanceof Function) {
@@ -758,8 +758,7 @@ public class ZetaSQLFunction {
next,
aliasFieldIndex,
row,
- expression,
- false);
+ expression);
}
seaTunnelRows = next;
}
@@ -774,8 +773,7 @@ public class ZetaSQLFunction {
List<SeaTunnelRow> next,
int aliasFieldIndex,
SeaTunnelRow row,
- Expression expression,
- boolean keepValueType) {
+ Expression expression) {
if (splitFieldValue == null) {
if (isUsingOuter) {
next.add(
@@ -798,13 +796,9 @@ public class ZetaSQLFunction {
if (!isUsingOuter && fieldValue == null) {
continue;
}
- Object value =
- fieldValue == null
- ? null
- : (keepValueType ? fieldValue :
String.valueOf(fieldValue));
next.add(
copySeaTunnelRowWithNewValue(
- outRowType.getTotalFields(), row,
aliasFieldIndex, value));
+ outRowType.getTotalFields(), row,
aliasFieldIndex, fieldValue));
}
} else {
throw new SeaTunnelRuntimeException(
@@ -865,14 +859,13 @@ public class ZetaSQLFunction {
seaTunnelDataTypes[columnIndex] =
seaTunnelDataType;
}
} else {
- // default string type
- SeaTunnelDataType seaTunnelDataType =
- PhysicalColumn.of(alias,
BasicType.STRING_TYPE, 10L, true, "", "")
- .getDataType();
+
+ ArrayType arrayType = (ArrayType)
zetaSQLType.getExpressionType(expression);
+
if (aliasIndex == -1) {
fieldNames = ArrayUtils.add(fieldNames, alias);
seaTunnelDataTypes =
- ArrayUtils.add(seaTunnelDataTypes,
seaTunnelDataType);
+ ArrayUtils.add(seaTunnelDataTypes,
arrayType.getElementType());
inputColumnsMapping.add(alias);
}
}
diff --git
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLType.java
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLType.java
index 127479536c..fda9105179 100644
---
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLType.java
+++
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLType.java
@@ -28,6 +28,7 @@ import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
import org.apache.seatunnel.api.table.type.SqlType;
import org.apache.seatunnel.common.exception.CommonErrorCodeDeprecated;
import org.apache.seatunnel.transform.exception.TransformException;
+import org.apache.seatunnel.transform.sql.zeta.functions.ArrayFunction;
import org.apache.commons.collections4.CollectionUtils;
@@ -448,6 +449,7 @@ public class ZetaSQLType {
case ZetaSQLFunction.TRUNCATE:
return BasicType.DOUBLE_TYPE;
case ZetaSQLFunction.ARRAY:
+ return ArrayFunction.castArrayTypeMapping(function,
inputRowType);
case ZetaSQLFunction.SPLIT:
return ArrayType.STRING_ARRAY_TYPE;
case ZetaSQLFunction.NOW:
diff --git
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/functions/ArrayFunction.java
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/functions/ArrayFunction.java
new file mode 100644
index 0000000000..dcbe814d6b
--- /dev/null
+++
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/functions/ArrayFunction.java
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.seatunnel.transform.sql.zeta.functions;
+
+import org.apache.seatunnel.api.table.type.ArrayType;
+import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
+import org.apache.seatunnel.common.utils.SeaTunnelException;
+
+import net.sf.jsqlparser.expression.DoubleValue;
+import net.sf.jsqlparser.expression.Expression;
+import net.sf.jsqlparser.expression.Function;
+import net.sf.jsqlparser.expression.LongValue;
+import net.sf.jsqlparser.expression.NullValue;
+import net.sf.jsqlparser.expression.StringValue;
+import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
+import net.sf.jsqlparser.schema.Column;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Runtime evaluation and result-type inference for the SQL {@code ARRAY(...)} function.
+ *
+ * <p>The element type is derived from the arguments: identical or assignable types are kept,
+ * mixed numeric types widen to the broadest numeric type present, and any other mix falls back to
+ * {@link String}.
+ */
+public class ArrayFunction {
+
+    /**
+     * Builds the array value for the given evaluated arguments.
+     *
+     * @param args evaluated argument values; individual elements may be {@code null}
+     * @return an array whose component type is the common type of the non-null arguments, with
+     *     every non-null element converted to that type; an empty {@code String[]} when {@code
+     *     args} is {@code null} or empty
+     */
+    public static Object[] array(List<Object> args) {
+        if (args == null || args.isEmpty()) {
+            // castArrayTypeMapping reports a string array for an empty argument list, so hand
+            // back a String[] (still an Object[]) instead of a bare Object[].
+            return new String[0];
+        }
+        Class<?> arrayType = getDataClassType(args);
+        Object[] result = (Object[]) java.lang.reflect.Array.newInstance(arrayType, args.size());
+        for (int i = 0; i < args.size(); i++) {
+            result[i] = convertToType(args.get(i), arrayType);
+        }
+
+        return result;
+    }
+
+    /**
+     * Infers the array type produced by an {@code ARRAY(...)} call from its parameter expressions.
+     *
+     * @param function the parsed function call
+     * @param inputRowType row type used to resolve the types of column arguments
+     * @return the inferred array type
+     */
+    public static ArrayType castArrayTypeMapping(Function function, SeaTunnelRowType inputRowType) {
+        return castArrayTypeMapping(getFunctionArgs(function, inputRowType));
+    }
+
+    /**
+     * Infers the array type for a list of argument classes.
+     *
+     * @param args argument classes; {@code null} entries are ignored
+     * @return the array type matching the common class of the arguments, or the string array type
+     *     when {@code args} is {@code null} or empty
+     */
+    public static ArrayType castArrayTypeMapping(List<Class<?>> args) {
+        if (args == null || args.isEmpty()) {
+            return ArrayType.STRING_ARRAY_TYPE;
+        }
+
+        Class<?> arrayType = getClassType(args);
+        return getSeaTunnelDataType(arrayType);
+    }
+
+    /**
+     * Maps a boxed Java element class to the corresponding array type; any class without a
+     * dedicated mapping yields the string array type.
+     */
+    private static ArrayType getSeaTunnelDataType(Class<?> clazz) {
+        // Labels must match Class#getSimpleName() of the boxed types exactly ("Float", "Short");
+        // lower-case labels never match and would silently fall through to the string array type.
+        // NOTE(review): Byte is handled by convertToType but has no mapping here, so a Byte
+        // element class is reported as a string array — confirm whether that is intended.
+        String className = clazz.getSimpleName();
+        switch (className) {
+            case "Integer":
+                return ArrayType.INT_ARRAY_TYPE;
+            case "Double":
+                return ArrayType.DOUBLE_ARRAY_TYPE;
+            case "Boolean":
+                return ArrayType.BOOLEAN_ARRAY_TYPE;
+            case "Long":
+                return ArrayType.LONG_ARRAY_TYPE;
+            case "Float":
+                return ArrayType.FLOAT_ARRAY_TYPE;
+            case "Short":
+                return ArrayType.SHORT_ARRAY_TYPE;
+            default:
+                return ArrayType.STRING_ARRAY_TYPE;
+        }
+    }
+
+    /**
+     * Returns the common element class of two classes: whichever one is assignable from the
+     * other, otherwise the wider numeric type when both are numeric, otherwise {@link String}.
+     */
+    private static Class<?> getArrayType(Class<?> type1, Class<?> type2) {
+        if (type1.isAssignableFrom(type2)) {
+            return type1;
+        }
+        if (type2.isAssignableFrom(type1)) {
+            return type2;
+        }
+        if (isNumericType(type1) && isNumericType(type2)) {
+            return getNumericCommonType(type1, type2);
+        }
+        return String.class;
+    }
+
+    /** Returns whether the class is one of Short, Integer, Long, Float or Double. */
+    private static boolean isNumericType(Class<?> type) {
+        return type == Short.class
+                || type == Integer.class
+                || type == Long.class
+                || type == Float.class
+                || type == Double.class;
+    }
+
+    /**
+     * Picks the widest of two numeric classes, checked in the order Double, Float, Long, Integer,
+     * Short; returns {@link String} if neither class is one of those.
+     */
+    private static Class<?> getNumericCommonType(Class<?> type1, Class<?> type2) {
+        if (type1 == Double.class || type2 == Double.class) {
+            return Double.class;
+        }
+        if (type1 == Float.class || type2 == Float.class) {
+            return Float.class;
+        }
+        if (type1 == Long.class || type2 == Long.class) {
+            return Long.class;
+        }
+        if (type1 == Integer.class || type2 == Integer.class) {
+            return Integer.class;
+        }
+        if (type1 == Short.class || type2 == Short.class) {
+            return Short.class;
+        }
+        return String.class;
+    }
+
+    /**
+     * Folds a list of argument classes into their common element class, skipping {@code null}
+     * entries; returns {@link String} when no non-null class is present.
+     */
+    private static Class<?> getClassType(List<Class<?>> args) {
+        Class<?> arrayType = null;
+        for (Class<?> obj : args) {
+            if (obj == null) {
+                continue;
+            }
+            if (arrayType == null) {
+                arrayType = obj;
+            } else {
+                arrayType = getArrayType(arrayType, obj);
+            }
+        }
+        return arrayType == null ? String.class : arrayType;
+    }
+
+    /**
+     * Folds the runtime classes of the argument values into their common element class, skipping
+     * {@code null} values; returns {@link String} when every value is {@code null}.
+     */
+    private static Class<?> getDataClassType(List<Object> args) {
+        Class<?> arrayType = null;
+        for (Object obj : args) {
+            if (obj == null) {
+                continue;
+            }
+            if (arrayType == null) {
+                arrayType = obj.getClass();
+            } else {
+                arrayType = getArrayType(arrayType, obj.getClass());
+            }
+        }
+        return arrayType == null ? String.class : arrayType;
+    }
+
+    /**
+     * Translates each parameter expression of the function into the Java class of its value.
+     *
+     * <p>{@code NULL} literals map to a {@code null} entry, double literals to {@link Double},
+     * columns to the type class of the matching input field, integer literals to {@link Integer}
+     * when they fit the int range and to {@link Long} otherwise, and string literals to {@link
+     * String}.
+     *
+     * @throws SeaTunnelException if a parameter is any other kind of expression
+     */
+    private static List<Class<?>> getFunctionArgs(
+            Function function, SeaTunnelRowType inputRowType) {
+        ExpressionList<Expression> expressionList =
+                (ExpressionList<Expression>) function.getParameters();
+        List<Class<?>> functionArgs = new ArrayList<>();
+        if (expressionList != null) {
+            for (Expression expression : expressionList.getExpressions()) {
+                if (expression instanceof NullValue) {
+                    functionArgs.add(null);
+                    continue;
+                }
+                if (expression instanceof DoubleValue) {
+                    functionArgs.add(Double.class);
+                    continue;
+                }
+                if (expression instanceof Column) {
+                    int columnIndex = inputRowType.indexOf(((Column) expression).getColumnName());
+                    functionArgs.add(inputRowType.getFieldType(columnIndex).getTypeClass());
+                    continue;
+                }
+
+                if (expression instanceof LongValue) {
+                    long longVal = ((LongValue) expression).getValue();
+                    if (longVal <= Integer.MAX_VALUE && longVal >= Integer.MIN_VALUE) {
+                        functionArgs.add(Integer.class);
+                    } else {
+                        functionArgs.add(Long.class);
+                    }
+                    continue;
+                }
+                if (expression instanceof StringValue) {
+                    functionArgs.add(String.class);
+                    continue;
+                }
+                throw new SeaTunnelException("unSupport expression: " + expression.toString());
+            }
+        }
+        return functionArgs;
+    }
+
+    /**
+     * Converts a value to the target element class. {@code null} values and values that are
+     * already instances of the target are returned unchanged; numeric targets are produced through
+     * the {@link Number} accessors and {@link String} through {@code toString()}.
+     *
+     * @throws SeaTunnelException if the target class is none of the supported types
+     */
+    private static Object convertToType(Object obj, Class<?> targetType) {
+        if (obj == null || targetType.isInstance(obj)) {
+            return obj;
+        }
+
+        if (targetType == Double.class) {
+            return ((Number) obj).doubleValue();
+        }
+        if (targetType == Float.class) {
+            return ((Number) obj).floatValue();
+        }
+        if (targetType == Long.class) {
+            return ((Number) obj).longValue();
+        }
+        if (targetType == Integer.class) {
+            return ((Number) obj).intValue();
+        }
+        if (targetType == Short.class) {
+            return ((Number) obj).shortValue();
+        }
+        if (targetType == Byte.class) {
+            return ((Number) obj).byteValue();
+        }
+        if (targetType == String.class) {
+            return obj.toString();
+        }
+
+        throw new SeaTunnelException("Cannot convert " + obj.getClass() + " to " + targetType);
+    }
+}