This is an automated email from the ASF dual-hosted git repository.
zzcclp pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/gluten.git
The following commit(s) were added to refs/heads/main by this push:
new b8bb3b20dc [GLUTEN-12174][CH] Fix flatten nullable inner array row handling (#12175)
b8bb3b20dc is described below
commit b8bb3b20dc0575c9769dd05c501b6ed089d97e7d
Author: lgbo <[email protected]>
AuthorDate: Fri May 29 18:05:49 2026 +0800
[GLUTEN-12174][CH] Fix flatten nullable inner array row handling (#12175)
SparkArrayFlatten handled Array(Nullable(Array(T))) by scanning all inner
arrays and returning a fully-null result column as soon as any inner array
was null. That made unrelated rows null, even though Spark's flatten
semantics only null the outer row that contains a null inner array.
Build a per-row null map for the result and mark as null only the outer rows
that contain a null inner array. All other rows keep their flattened values,
built from the flattened array offsets. Add a ClickHouse backend regression
test in which the first row contains a null inner array and the second row
remains non-null.
---
.../execution/GlutenFunctionValidateSuite.scala | 16 +++++++++++++++
.../local-engine/Functions/SparkArrayFlatten.cpp | 23 +++++++++++++++-------
2 files changed, 32 insertions(+), 7 deletions(-)
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
index 66ea6c9866..07d44cc309 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
@@ -682,6 +682,22 @@ class GlutenFunctionValidateSuite extends
GlutenClickHouseWholeStageTransformerS
}
}
+ test("test flatten with nullable inner arrays") {
+ val sql =
+ """
+ |select id, flatten(arr)
+ |from (
+ | select id,
+ | if(id = 0,
+ | array(array(cast(id + 1 as int)), cast(null as array<int>)),
+ | array(array(cast(id + 1 as int)))) as arr
+ | from range(2)
+ |)
+ |order by id
+ |""".stripMargin
+ runQueryAndCompare(sql)(checkGlutenPlan[ProjectExecTransformer])
+ }
+
test("test common subexpression eliminate") {
def checkOperatorCount[T <: TransformSupport](count: Int)(df:
DataFrame)(implicit
tag: ClassTag[T]): Unit = {
diff --git a/cpp-ch/local-engine/Functions/SparkArrayFlatten.cpp
b/cpp-ch/local-engine/Functions/SparkArrayFlatten.cpp
index 96faa9d1dc..7ead48cac1 100644
--- a/cpp-ch/local-engine/Functions/SparkArrayFlatten.cpp
+++ b/cpp-ch/local-engine/Functions/SparkArrayFlatten.cpp
@@ -16,6 +16,7 @@
*/
#include <Columns/ColumnArray.h>
#include <Columns/ColumnNullable.h>
+#include <Columns/ColumnVector.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeNullable.h>
#include <Functions/FunctionFactory.h>
@@ -107,19 +108,27 @@ result: Row 1: [1, 2, 3], Row2: [4]
const IColumn::Offsets * prev_offsets = &src_offsets;
const IColumn * prev_data = &src_col->getData();
bool nullable = prev_data->isNullable();
- // when array has null element, return null
+ ColumnUInt8::MutablePtr result_null_map;
+ // When an inner array is null, only the corresponding outer row is
null.
if (nullable)
{
const ColumnNullable * nullable_column =
checkAndGetColumn<ColumnNullable>(prev_data);
prev_data = nullable_column->getNestedColumnPtr().get();
- for (size_t i = 0; i < nullable_column->size(); i++)
+ result_null_map = ColumnUInt8::create(input_rows_count, 0);
+ auto & result_null_map_data = result_null_map->getData();
+ size_t prev_offset = 0;
+ for (size_t row = 0; row < input_rows_count; ++row)
{
- if (nullable_column->isNullAt(i))
+ const auto current_offset = src_offsets[row];
+ for (size_t i = prev_offset; i < current_offset; ++i)
{
- auto res= nullable_column->cloneEmpty();
- res->insertManyDefaults(input_rows_count);
- return res;
+ if (nullable_column->isNullAt(i))
+ {
+ result_null_map_data[row] = 1;
+ break;
+ }
}
+ prev_offset = current_offset;
}
}
if (isNothing(prev_data->getDataType()))
@@ -142,7 +151,7 @@ result: Row 1: [1, 2, 3], Row2: [4]
prev_data->getPtr(),
result_offsets_column ? std::move(result_offsets_column) :
src_col->getOffsetsPtr());
if (nullable)
- return makeNullable(res);
+ return ColumnNullable::create(std::move(res),
std::move(result_null_map));
return res;
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]