This is an automated email from the ASF dual-hosted git repository.
exmy pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new dab93e8216 Revert "cast for values built from nothing types (#9042)"
(#9086)
dab93e8216 is described below
commit dab93e82167f8759c61e2d277761cbcc491183a1
Author: lgbo <[email protected]>
AuthorDate: Fri Mar 21 16:59:23 2025 +0800
Revert "cast for values built from nothing types (#9042)" (#9086)
This reverts commit 5c000cd923c1f5dec7397bf5b6916edc6df52d6f.
---
.../GlutenClickHouseTPCHSaltNullParquetSuite.scala | 67 -------
cpp-ch/local-engine/Parser/ExpressionParser.cpp | 216 ++++++++-------------
cpp-ch/local-engine/Parser/ExpressionParser.h | 5 -
3 files changed, 83 insertions(+), 205 deletions(-)
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
index affc35bc27..3ce03565e8 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
@@ -3396,72 +3396,5 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends
GlutenClickHouseTPCHAbstr
compareResultsAgainstVanillaSpark(select_sql, true, { _ => })
}
- test("GLUTEN-9032 default values from nothing types") {
- val sql1 =
- """
- |select a, b, c from (
- | select
- | n_regionkey as a, n_nationkey as b, array() as c
- | from nation where n_nationkey % 2 = 1
- | union all
- | select n_nationkey as a, n_regionkey as b, array('123') as c
- | from nation where n_nationkey % 2 = 0
- |)
- """.stripMargin
- compareResultsAgainstVanillaSpark(sql1, true, { _ => })
-
- val sql2 =
- """
- |select a, b, c from (
- | select
- | n_regionkey as a, n_nationkey as b, array() as c
- | from nation where n_nationkey % 2 = 1
- | union all
- | select n_nationkey as a, n_regionkey as b, array('123', null) as c
- | from nation where n_nationkey % 2 = 0
- |)
- """.stripMargin
- compareResultsAgainstVanillaSpark(sql2, true, { _ => })
-
- val sql3 =
- """
- |select a, b, c from (
- | select
- | n_regionkey as a, n_nationkey as b, array() as c
- | from nation where n_nationkey % 2 = 1
- | union all
- | select n_nationkey as a, n_regionkey as b, array(null) as c
- | from nation where n_nationkey % 2 = 0
- |)
- """.stripMargin
- compareResultsAgainstVanillaSpark(sql3, true, { _ => })
-
- val sql4 =
- """
- |select a, b, c from (
- | select
- | n_regionkey as a, n_nationkey as b, map() as c
- | from nation where n_nationkey % 2 = 1
- | union all
- | select n_nationkey as a, n_regionkey as b, map('123', 1) as c
- | from nation where n_nationkey % 2 = 0
- |)
- """.stripMargin
- compareResultsAgainstVanillaSpark(sql4, true, { _ => })
-
- val sql5 =
- """
- |select a, b, c from (
- | select
- | n_regionkey as a, n_nationkey as b, map() as c
- | from nation where n_nationkey % 2 = 1
- | union all
- | select n_nationkey as a, n_regionkey as b, map('123', null) as c
- | from nation where n_nationkey % 2 = 0
- |)
- """.stripMargin
- compareResultsAgainstVanillaSpark(sql5, true, { _ => })
-
- }
}
// scalastyle:on line.size.limit
diff --git a/cpp-ch/local-engine/Parser/ExpressionParser.cpp
b/cpp-ch/local-engine/Parser/ExpressionParser.cpp
index f297f94fc5..4261241f25 100644
--- a/cpp-ch/local-engine/Parser/ExpressionParser.cpp
+++ b/cpp-ch/local-engine/Parser/ExpressionParser.cpp
@@ -305,7 +305,89 @@ ExpressionParser::NodeRawConstPtr
ExpressionParser::parseExpression(ActionsDAG &
case substrait::Expression::RexTypeCase::kCast: {
if (!rel.cast().has_type() || !rel.cast().has_input())
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Doesn't have type
or input in cast node.");
- return parseCast(actions_dag, rel);
+ ActionsDAG::NodeRawConstPtrs args;
+
+ const auto & input = rel.cast().input();
+ args.emplace_back(parseExpression(actions_dag, input));
+
+ const auto & substrait_type = rel.cast().type();
+ const auto & input_type = args[0]->result_type;
+ DataTypePtr denull_input_type = removeNullable(input_type);
+ DataTypePtr output_type = TypeParser::parseType(substrait_type);
+ DataTypePtr denull_output_type = removeNullable(output_type);
+ const ActionsDAG::Node * result_node = nullptr;
+ if (substrait_type.has_binary())
+ {
+ /// Spark cast(x as BINARY) -> CH reinterpretAsStringSpark(x)
+ result_node = toFunctionNode(actions_dag,
"reinterpretAsStringSpark", args);
+ }
+ else if (isString(denull_input_type) &&
isDate32(denull_output_type))
+ result_node = toFunctionNode(actions_dag, "sparkToDate", args);
+ else if (isString(denull_input_type) &&
isDateTime64(denull_output_type))
+ result_node = toFunctionNode(actions_dag, "sparkToDateTime",
args);
+ else if (isDecimal(denull_input_type) &&
isString(denull_output_type))
+ {
+ /// Spark cast(x as STRING) if x is Decimal -> CH
toDecimalString(x, scale)
+ UInt8 scale = getDecimalScale(*denull_input_type);
+ args.emplace_back(addConstColumn(actions_dag,
std::make_shared<DataTypeUInt8>(), Field(scale)));
+ result_node = toFunctionNode(actions_dag, "toDecimalString",
args);
+ }
+ else if (isFloat(denull_input_type) && isInt(denull_output_type))
+ {
+ String function_name = "sparkCastFloatTo" +
denull_output_type->getName();
+ result_node = toFunctionNode(actions_dag, function_name, args);
+ }
+ else if (isFloat(denull_input_type) &&
isString(denull_output_type))
+ result_node = toFunctionNode(actions_dag,
"sparkCastFloatToString", args);
+ else if ((isDecimal(denull_input_type) ||
isNativeNumber(denull_input_type)) && substrait_type.has_decimal())
+ {
+ int precision = substrait_type.decimal().precision();
+ int scale = substrait_type.decimal().scale();
+ if (precision)
+ {
+ args.emplace_back(addConstColumn(actions_dag,
std::make_shared<DataTypeInt32>(), precision));
+ args.emplace_back(addConstColumn(actions_dag,
std::make_shared<DataTypeInt32>(), scale));
+ result_node = toFunctionNode(actions_dag,
"checkDecimalOverflowSparkOrNull", args);
+ }
+ }
+ else if (isMap(denull_input_type) && isString(denull_output_type))
+ {
+ // ISSUE-7389: spark cast(map to string) has different
behavior with CH cast(map to string)
+ auto map_input_type = std::static_pointer_cast<const
DataTypeMap>(denull_input_type);
+ args.emplace_back(addConstColumn(actions_dag,
map_input_type->getKeyType(), map_input_type->getKeyType()->getDefault()));
+ args.emplace_back(
+ addConstColumn(actions_dag,
map_input_type->getValueType(), map_input_type->getValueType()->getDefault()));
+ result_node = toFunctionNode(actions_dag,
"sparkCastMapToString", args);
+ }
+ else if (isArray(denull_input_type) &&
isString(denull_output_type))
+ {
+ // ISSUE-7602: spark cast(array to string) has different
result with CH cast(array to string)
+ result_node = toFunctionNode(actions_dag,
"sparkCastArrayToString", args);
+ }
+ else if (isString(denull_input_type) && substrait_type.has_bool_())
+ {
+ /// cast(string to boolean)
+ args.emplace_back(addConstColumn(actions_dag,
std::make_shared<DataTypeString>(), output_type->getName()));
+ result_node = toFunctionNode(actions_dag,
"accurateCastOrNull", args);
+ }
+ else if (isString(denull_input_type) && isInt(denull_output_type))
+ {
+ /// Spark cast(x as INT) if x is String -> CH cast(trim(x) as
INT)
+ /// Refer to
https://github.com/apache/incubator-gluten/issues/4956 and
https://github.com/apache/incubator-gluten/issues/8598
+ auto trim_str_arg = addConstColumn(actions_dag,
std::make_shared<DataTypeString>(), " \t\n\r\f");
+ args[0] = toFunctionNode(actions_dag, "trimBothSpark",
{args[0], trim_str_arg});
+ args.emplace_back(addConstColumn(actions_dag,
std::make_shared<DataTypeString>(), output_type->getName()));
+ result_node = toFunctionNode(actions_dag, "CAST", args);
+ }
+ else
+ {
+ /// Common process: CAST(input, type)
+ args.emplace_back(addConstColumn(actions_dag,
std::make_shared<DataTypeString>(), output_type->getName()));
+ result_node = toFunctionNode(actions_dag, "CAST", args);
+ }
+
+ actions_dag.addOrReplaceInOutputs(*result_node);
+ return result_node;
}
case substrait::Expression::RexTypeCase::kIfThen: {
@@ -434,138 +516,6 @@ ExpressionParser::NodeRawConstPtr
ExpressionParser::parseExpression(ActionsDAG &
}
}
-bool ExpressionParser::isValueFromNothingType(const substrait::Expression &
expr) const
-{
- const auto & cast_input = expr.cast().input();
- // null literal
- if (cast_input.has_literal() && cast_input.literal().has_null() &&
cast_input.literal().null().has_nothing())
- return true;
- else if (cast_input.has_scalar_function())
- {
- auto function_name =
getFunctionNameInSignature(cast_input.scalar_function());
- // empty map
- if (cast_input.scalar_function().output_type().has_map())
- {
- const auto & map_type =
cast_input.scalar_function().output_type().map();
- if (map_type.key().has_nothing() && map_type.value().has_nothing())
- return true;
- }
- // empty array
- else if (cast_input.scalar_function().output_type().has_list())
- {
- const auto & list_type =
cast_input.scalar_function().output_type().list();
- if (list_type.type().has_nothing())
- return true;
- }
- }
- return false;
-}
-
-ExpressionParser::NodeRawConstPtr ExpressionParser::parseCast(DB::ActionsDAG &
actions_dag, const substrait::Expression & cast_expr) const
-{
- if (isValueFromNothingType(cast_expr))
- return parseNothingValuesCast(actions_dag, cast_expr);
- return parseNormalValuesCast(actions_dag, cast_expr);
-}
-
-// Build a default value from the output type of `cast` when the `cast`'s
input is built from `nothing` type.
-// `nothing` type is wrapped in nullable in `TypeParser`, it could cause
nullability mismatch.
-ExpressionParser::NodeRawConstPtr
-ExpressionParser::parseNothingValuesCast(DB::ActionsDAG & actions_dag, const
substrait::Expression & cast_expr) const
-{
- auto ch_type = TypeParser::parseType(cast_expr.cast().type());
- // use the target type to create the default value.
- auto default_value = ch_type->getDefault();
- return addConstColumn(actions_dag, ch_type, default_value);
-}
-
-ExpressionParser::NodeRawConstPtr
-ExpressionParser::parseNormalValuesCast(DB::ActionsDAG & actions_dag, const
substrait::Expression & cast_expr) const
-{
- ActionsDAG::NodeRawConstPtrs args;
- const auto & input = cast_expr.cast().input();
- args.emplace_back(parseExpression(actions_dag, input));
-
- const auto & substrait_type = cast_expr.cast().type();
- const auto & input_type = args[0]->result_type;
- DataTypePtr denull_input_type = removeNullable(input_type);
- DataTypePtr output_type = TypeParser::parseType(substrait_type);
- DataTypePtr denull_output_type = removeNullable(output_type);
- const ActionsDAG::Node * result_node = nullptr;
- if (substrait_type.has_binary())
- {
- /// Spark cast(x as BINARY) -> CH reinterpretAsStringSpark(x)
- result_node = toFunctionNode(actions_dag, "reinterpretAsStringSpark",
args);
- }
- else if (isString(denull_input_type) && isDate32(denull_output_type))
- result_node = toFunctionNode(actions_dag, "sparkToDate", args);
- else if (isString(denull_input_type) && isDateTime64(denull_output_type))
- result_node = toFunctionNode(actions_dag, "sparkToDateTime", args);
- else if (isDecimal(denull_input_type) && isString(denull_output_type))
- {
- /// Spark cast(x as STRING) if x is Decimal -> CH toDecimalString(x,
scale)
- UInt8 scale = getDecimalScale(*denull_input_type);
- args.emplace_back(addConstColumn(actions_dag,
std::make_shared<DataTypeUInt8>(), Field(scale)));
- result_node = toFunctionNode(actions_dag, "toDecimalString", args);
- }
- else if (isFloat(denull_input_type) && isInt(denull_output_type))
- {
- String function_name = "sparkCastFloatTo" +
denull_output_type->getName();
- result_node = toFunctionNode(actions_dag, function_name, args);
- }
- else if (isFloat(denull_input_type) && isString(denull_output_type))
- result_node = toFunctionNode(actions_dag, "sparkCastFloatToString",
args);
- else if ((isDecimal(denull_input_type) ||
isNativeNumber(denull_input_type)) && substrait_type.has_decimal())
- {
- int precision = substrait_type.decimal().precision();
- int scale = substrait_type.decimal().scale();
- if (precision)
- {
- args.emplace_back(addConstColumn(actions_dag,
std::make_shared<DataTypeInt32>(), precision));
- args.emplace_back(addConstColumn(actions_dag,
std::make_shared<DataTypeInt32>(), scale));
- result_node = toFunctionNode(actions_dag,
"checkDecimalOverflowSparkOrNull", args);
- }
- }
- else if (isMap(denull_input_type) && isString(denull_output_type))
- {
- // ISSUE-7389: spark cast(map to string) has different behavior with
CH cast(map to string)
- auto map_input_type = std::static_pointer_cast<const
DataTypeMap>(denull_input_type);
- args.emplace_back(addConstColumn(actions_dag,
map_input_type->getKeyType(), map_input_type->getKeyType()->getDefault()));
- args.emplace_back(addConstColumn(actions_dag,
map_input_type->getValueType(), map_input_type->getValueType()->getDefault()));
- result_node = toFunctionNode(actions_dag, "sparkCastMapToString",
args);
- }
- else if (isArray(denull_input_type) && isString(denull_output_type))
- {
- // ISSUE-7602: spark cast(array to string) has different result with
CH cast(array to string)
- result_node = toFunctionNode(actions_dag, "sparkCastArrayToString",
args);
- }
- else if (isString(denull_input_type) && substrait_type.has_bool_())
- {
- /// cast(string to boolean)
- args.emplace_back(addConstColumn(actions_dag,
std::make_shared<DataTypeString>(), output_type->getName()));
- result_node = toFunctionNode(actions_dag, "accurateCastOrNull", args);
- }
- else if (isString(denull_input_type) && isInt(denull_output_type))
- {
- /// Spark cast(x as INT) if x is String -> CH cast(trim(x) as INT)
- /// Refer to https://github.com/apache/incubator-gluten/issues/4956
and https://github.com/apache/incubator-gluten/issues/8598
- auto trim_str_arg = addConstColumn(actions_dag,
std::make_shared<DataTypeString>(), " \t\n\r\f");
- args[0] = toFunctionNode(actions_dag, "trimBothSpark", {args[0],
trim_str_arg});
- args.emplace_back(addConstColumn(actions_dag,
std::make_shared<DataTypeString>(), output_type->getName()));
- result_node = toFunctionNode(actions_dag, "CAST", args);
- }
- else
- {
- /// Common process: CAST(input, type)
- args.emplace_back(addConstColumn(actions_dag,
std::make_shared<DataTypeString>(), output_type->getName()));
- result_node = toFunctionNode(actions_dag, "CAST", args);
- }
-
- actions_dag.addOrReplaceInOutputs(*result_node);
- return result_node;
-}
-
-
DB::ActionsDAG
ExpressionParser::expressionsToActionsDAG(const
std::vector<substrait::Expression> & expressions, const DB::Block & header)
const
{
diff --git a/cpp-ch/local-engine/Parser/ExpressionParser.h
b/cpp-ch/local-engine/Parser/ExpressionParser.h
index 6e7ae8a924..1e4a48282a 100644
--- a/cpp-ch/local-engine/Parser/ExpressionParser.h
+++ b/cpp-ch/local-engine/Parser/ExpressionParser.h
@@ -89,10 +89,5 @@ private:
static bool areEqualNodes(NodeRawConstPtr a, NodeRawConstPtr b);
NodeRawConstPtr findFirstStructureEqualNode(NodeRawConstPtr target, const
DB::ActionsDAG & actions_dag) const;
-
- NodeRawConstPtr parseCast(DB::ActionsDAG & actions_dag, const
substrait::Expression & cast_expr) const;
- bool isValueFromNothingType(const substrait::Expression & expr) const;
- NodeRawConstPtr parseNothingValuesCast(DB::ActionsDAG & actions_dag, const
substrait::Expression & cast_expr) const;
- NodeRawConstPtr parseNormalValuesCast(DB::ActionsDAG & actions_dag, const
substrait::Expression & cast_expr) const;
};
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]