This is an automated email from the ASF dual-hosted git repository.
zclllyybb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new daab2c0e827 [Feature](function) Support function ARRAY_CROSS_PRODUCT
(#64031)
daab2c0e827 is described below
commit daab2c0e827d584a8a83857b5645d7a680996ad4
Author: linrrarity <[email protected]>
AuthorDate: Wed Jun 3 11:50:48 2026 +0800
[Feature](function) Support function ARRAY_CROSS_PRODUCT (#64031)
Issue Number: https://github.com/apache/doris/issues/48203
Related PR: https://github.com/apache/doris/pull/59223
doc: https://github.com/apache/doris-website/pull/3891
Problem Summary:
Support the `ARRAY_CROSS_PRODUCT` function (alias `CROSS_PRODUCT`), which returns the cross product of two 3-element float arrays.
```sql
Doris> SELECT CROSS_PRODUCT([1, 2, 3], [2, 3, 4]);
+-------------------------------------+
| CROSS_PRODUCT([1, 2, 3], [2, 3, 4]) |
+-------------------------------------+
| [-1, 2, -1] |
+-------------------------------------+
1 row in set (0.021 sec)
Doris> SELECT CROSS_PRODUCT([1, 2, 3], NULL);
+--------------------------------+
| CROSS_PRODUCT([1, 2, 3], NULL) |
+--------------------------------+
| NULL |
+--------------------------------+
1 row in set (0.009 sec)
Doris> SELECT CROSS_PRODUCT([1, NULL, 3], [1, 2, 3]);
ERROR 1105 (HY000): errCode = 2, detailMessage =
(127.0.0.1)[INVALID_ARGUMENT]function array_cross_product cannot have null
Doris> SELECT CROSS_PRODUCT([1, 2, 3, 4], [1, 2, 3, 4]);
ERROR 1105 (HY000): errCode = 2, detailMessage =
(127.0.0.1)[INVALID_ARGUMENT]function array_cross_product requires both input
arrays to have exactly 3 elements, got 4 and 4
```
---
.../array/function_array_cross_product.cpp | 190 ++++++++++++++
.../function/array/function_array_register.cpp | 2 +
.../function/function_array_cross_product_test.cpp | 285 +++++++++++++++++++++
.../doris/catalog/BuiltinScalarFunctions.java | 2 +
.../functions/executable/ArrayArithmetic.java | 57 +++++
.../functions/scalar/ArrayCrossProduct.java | 75 ++++++
.../expressions/visitor/ScalarFunctionVisitor.java | 5 +
.../array_functions/test_array_cross_product.out | 35 +++
.../test_array_cross_product.groovy | 149 +++++++++++
9 files changed, 800 insertions(+)
diff --git a/be/src/exprs/function/array/function_array_cross_product.cpp
b/be/src/exprs/function/array/function_array_cross_product.cpp
new file mode 100644
index 00000000000..a8fff7a64e1
--- /dev/null
+++ b/be/src/exprs/function/array/function_array_cross_product.cpp
@@ -0,0 +1,190 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <memory>
+
+#include "common/status.h"
+#include "core/assert_cast.h"
+#include "core/block/block.h"
+#include "core/block/column_numbers.h"
+#include "core/block/column_with_type_and_name.h"
+#include "core/column/column.h"
+#include "core/column/column_array.h"
+#include "core/column/column_const.h"
+#include "core/column/column_nullable.h"
+#include "core/column/column_vector.h"
+#include "core/data_type/data_type.h"
+#include "core/data_type/data_type_array.h"
+#include "core/data_type/data_type_nullable.h"
+#include "core/data_type/data_type_number.h"
+#include "exprs/function/array/function_array_utils.h"
+#include "exprs/function/function.h"
+#include "exprs/function/simple_function_factory.h"
+
+namespace doris {
+class FunctionContext;
+} // namespace doris
+
+namespace doris {
+
+// FunctionArrayCrossProduct implements `array_cross_product(ARRAY<FLOAT>, ARRAY<FLOAT>)`
+// (registered below together with the alias `cross_product`): the 3-D vector cross product.
+//
+// Per row:
+//  - if either top-level array is NULL, the result row is NULL;
+//  - otherwise both arrays must hold exactly VECTOR_DIM (3) elements and no NULL element,
+//    or the whole execution fails with INVALID_ARGUMENT;
+//  - the result is a VECTOR_DIM-element ARRAY<Nullable(FLOAT)> whose elements are never NULL.
+class FunctionArrayCrossProduct : public IFunction {
+public:
+    using ColumnType = PrimitiveTypeTraits<TYPE_FLOAT>::ColumnType;
+
+    static constexpr auto name = "array_cross_product";
+
+    static FunctionPtr create() { return std::make_shared<FunctionArrayCrossProduct>(); }
+
+    String get_name() const override { return name; }
+    size_t get_number_of_arguments() const override { return 2; }
+
+    // NULLs are handled per row in execute_impl (top-level NULL -> NULL row, NULL element ->
+    // error), so the framework's default NULL handling is turned off.
+    bool use_default_implementation_for_nulls() const override { return false; }
+
+    // ARRAY<Nullable(FLOAT)>, additionally wrapped in Nullable when either argument is nullable.
+    DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
+        DataTypePtr result_type =
+                std::make_shared<DataTypeArray>(make_nullable(std::make_shared<DataTypeFloat32>()));
+        if (arguments[0]->is_nullable() || arguments[1]->is_nullable()) {
+            return make_nullable(result_type);
+        }
+        return result_type;
+    }
+
+    Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
+                        uint32_t result, size_t input_rows_count) const override {
+        // Constant arguments are unwrapped to their underlying column; index_check_const
+        // below picks which physical row to read for each logical row.
+        const auto& [left_column, left_const] =
+                unpack_if_const(block.get_by_position(arguments[0]).column);
+        const auto& [right_column, right_const] =
+                unpack_if_const(block.get_by_position(arguments[1]).column);
+
+        ColumnArrayExecutionDatas array_datas(2);
+        if (!extract_column_array_info(*left_column, array_datas[0]) ||
+            !extract_column_array_info(*right_column, array_datas[1])) [[unlikely]] {
+            return Status::RuntimeError("execute failed, unsupported types for function {}({}, {})",
+                                        get_name(),
+                                        block.get_by_position(arguments[0]).type->get_name(),
+                                        block.get_by_position(arguments[1]).type->get_name());
+        }
+        const auto& left_data = array_datas[0];
+        const auto& right_data = array_datas[1];
+
+        // The FE signature declares ARRAY<FLOAT> arguments, so the element column (with any
+        // Nullable wrapper already stripped by extract_column_array_info) is ColumnFloat32.
+        const auto& left_nested_data =
+                assert_cast<const ColumnFloat32&>(*left_data.nested_col).get_data();
+        const auto& right_nested_data =
+                assert_cast<const ColumnFloat32&>(*right_data.nested_col).get_data();
+
+        auto res_data = ColumnType::create();
+        auto& res_values = res_data->get_data();
+        res_values.reserve(input_rows_count * VECTOR_DIM);
+        auto res_offsets = ColumnArray::ColumnOffsets::create();
+        auto& offsets = res_offsets->get_data();
+        offsets.resize(input_rows_count);
+        auto result_nested_null_map = ColumnUInt8::create();
+        auto& result_nested_null_map_data = result_nested_null_map->get_data();
+        result_nested_null_map_data.reserve(input_rows_count * VECTOR_DIM);
+        auto result_null_map = ColumnUInt8::create(input_rows_count, 0);
+        auto& result_null_map_data = result_null_map->get_data();
+        size_t result_offset = 0;
+
+        for (size_t row = 0; row < input_rows_count; ++row) {
+            const auto left_row = index_check_const(row, left_const);
+            const auto right_row = index_check_const(row, right_const);
+            const bool is_null =
+                    is_top_null(left_data, left_row) || is_top_null(right_data, right_row);
+            result_null_map_data[row] = is_null;
+            if (is_null) {
+                // A NULL row owns an empty slice: the result offset does not advance.
+                offsets[row] = result_offset;
+                continue;
+            }
+
+            // For row 0 this reads offsets[-1]; presumably the offsets container's left
+            // padding yields 0 there (common pattern for array offsets in this codebase).
+            const size_t left_begin = (*left_data.offsets_ptr)[left_row - 1];
+            const size_t right_begin = (*right_data.offsets_ptr)[right_row - 1];
+            const size_t dim1 = (*left_data.offsets_ptr)[left_row] - left_begin;
+            const size_t dim2 = (*right_data.offsets_ptr)[right_row] - right_begin;
+
+            // Validate the shape first so has_nested_null / compute_cross_product never read
+            // past this row's VECTOR_DIM elements.
+            RETURN_IF_ERROR(check_vector_dims(dim1, dim2));
+            if (has_nested_null(left_data, left_begin) ||
+                has_nested_null(right_data, right_begin)) {
+                return Status::InvalidArgument("function {} cannot have null", get_name());
+            }
+
+            compute_cross_product(left_nested_data, left_begin, right_nested_data, right_begin,
+                                  res_values, result_nested_null_map_data);
+            result_offset += VECTOR_DIM;
+            offsets[row] = result_offset;
+        }
+
+        auto result_column = ColumnArray::create(
+                ColumnNullable::create(std::move(res_data), std::move(result_nested_null_map)),
+                std::move(res_offsets));
+        // Only attach the row null map when the planned result type is nullable.
+        if (block.get_by_position(result).type->is_nullable()) {
+            block.replace_by_position(result, ColumnNullable::create(std::move(result_column),
+                                                                     std::move(result_null_map)));
+        } else {
+            block.replace_by_position(result, std::move(result_column));
+        }
+        return Status::OK();
+    }
+
+private:
+    // Required length of both inputs, and the length of every non-NULL result array.
+    static constexpr size_t VECTOR_DIM = 3;
+
+    // True when `row` of the (possibly Nullable) array argument is a top-level NULL.
+    static bool is_top_null(const ColumnArrayExecutionData& data, size_t row) {
+        return data.array_nullmap_data && data.array_nullmap_data[row];
+    }
+
+    // True when any of the VECTOR_DIM elements starting at `begin` is NULL.
+    // Callers must have validated the row length with check_vector_dims first.
+    static bool has_nested_null(const ColumnArrayExecutionData& data, size_t begin) {
+        if (data.nested_nullmap_data == nullptr) {
+            return false;
+        }
+        for (size_t i = 0; i < VECTOR_DIM; ++i) {
+            if (data.nested_nullmap_data[begin + i]) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    // Returns InvalidArgument unless both arrays hold exactly VECTOR_DIM elements.
+    Status check_vector_dims(size_t dim1, size_t dim2) const {
+        if (dim1 != VECTOR_DIM || dim2 != VECTOR_DIM) {
+            return Status::InvalidArgument(
+                    "function {} requires both input arrays to have exactly {} elements, got {} "
+                    "and {}",
+                    get_name(), VECTOR_DIM, dim1, dim2);
+        }
+        return Status::OK();
+    }
+
+    // Appends x × y (x = left row, y = right row) to res_values and three "not NULL" flags
+    // to the element null map.
+    static void compute_cross_product(const ColumnFloat32::Container& left_nested_data,
+                                      size_t left_begin,
+                                      const ColumnFloat32::Container& right_nested_data,
+                                      size_t right_begin, ColumnFloat32::Container& res_values,
+                                      ColumnUInt8::Container& result_nested_null_map_data) {
+        static_assert(VECTOR_DIM == 3, "the cross product formula below is 3-D only");
+        float x0 = left_nested_data[left_begin];
+        float x1 = left_nested_data[left_begin + 1];
+        float x2 = left_nested_data[left_begin + 2];
+        float y0 = right_nested_data[right_begin];
+        float y1 = right_nested_data[right_begin + 1];
+        float y2 = right_nested_data[right_begin + 2];
+        res_values.push_back(x1 * y2 - x2 * y1);
+        res_values.push_back(x2 * y0 - x0 * y2);
+        res_values.push_back(x0 * y1 - x1 * y0);
+        result_nested_null_map_data.push_back(0);
+        result_nested_null_map_data.push_back(0);
+        result_nested_null_map_data.push_back(0);
+    }
+};
+
+// Registers FunctionArrayCrossProduct and the `cross_product` alias for it.
+void register_function_array_cross_product(SimpleFunctionFactory& factory) {
+    factory.register_function<FunctionArrayCrossProduct>();
+    // Use the class constant rather than repeating the literal so the alias target cannot
+    // drift from FunctionArrayCrossProduct::name.
+    factory.register_alias(FunctionArrayCrossProduct::name, "cross_product");
+}
+
+} // namespace doris
diff --git a/be/src/exprs/function/array/function_array_register.cpp
b/be/src/exprs/function/array/function_array_register.cpp
index 7eca80dc5db..75c1401dd95 100644
--- a/be/src/exprs/function/array/function_array_register.cpp
+++ b/be/src/exprs/function/array/function_array_register.cpp
@@ -57,6 +57,7 @@ void
register_function_array_filter_function(SimpleFunctionFactory&);
void register_function_array_splits(SimpleFunctionFactory&);
void register_function_array_contains_all(SimpleFunctionFactory&);
void register_function_array_match(SimpleFunctionFactory&);
+void register_function_array_cross_product(SimpleFunctionFactory&);
void register_function_array(SimpleFunctionFactory& factory) {
register_function_array_flatten(factory);
@@ -95,6 +96,7 @@ void register_function_array(SimpleFunctionFactory& factory) {
register_function_array_splits(factory);
register_function_array_contains_all(factory);
register_function_array_match(factory);
+ register_function_array_cross_product(factory);
}
} // namespace doris
diff --git a/be/test/exprs/function/function_array_cross_product_test.cpp
b/be/test/exprs/function/function_array_cross_product_test.cpp
new file mode 100644
index 00000000000..f53a4f04cb5
--- /dev/null
+++ b/be/test/exprs/function/function_array_cross_product_test.cpp
@@ -0,0 +1,285 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "common/status.h"
+#include "core/block/block.h"
+#include "core/column/column_array.h"
+#include "core/column/column_const.h"
+#include "core/column/column_nullable.h"
+#include "core/data_type/data_type_array.h"
+#include "core/data_type/data_type_nullable.h"
+#include "core/data_type/data_type_number.h"
+#include "exprs/function/simple_function_factory.h"
+#include "testutil/function_utils.h"
+
+namespace doris {
+
+namespace {
+
+// Test fixture shape: one inner vector per array row, holding that row's elements.
+using FloatRows = std::vector<std::vector<float>>;
+
+// Builds a non-nullable ARRAY column whose element column is the column type of
+// ElementType; each entry of `rows` becomes one array row (values cast to the element type).
+template <PrimitiveType ElementType>
+ColumnPtr create_array_column(const FloatRows& rows) {
+    using NestedColumnType = typename PrimitiveTypeTraits<ElementType>::ColumnType;
+    using ElementValue = typename NestedColumnType::value_type;
+
+    auto elements = NestedColumnType::create();
+    auto row_ends = ColumnArray::ColumnOffsets::create();
+    auto& element_values = elements->get_data();
+    auto& end_values = row_ends->get_data();
+
+    size_t running_end = 0;
+    for (const auto& one_row : rows) {
+        running_end += one_row.size();
+        for (const float v : one_row) {
+            element_values.push_back(static_cast<ElementValue>(v));
+        }
+        end_values.push_back(running_end);
+    }
+    return ColumnArray::create(std::move(elements), std::move(row_ends));
+}
+
+// ARRAY<Nullable(FLOAT)>: the function's (non-nullable) result type.
+DataTypePtr array_nullable_float_type() {
+    DataTypePtr nullable_element = make_nullable(std::make_shared<DataTypeFloat32>());
+    return std::make_shared<DataTypeArray>(nullable_element);
+}
+
+// ARRAY<T> with a non-nullable element type taken from ElementType.
+template <PrimitiveType ElementType>
+DataTypePtr array_type() {
+    using ElementDataType = typename PrimitiveTypeTraits<ElementType>::DataType;
+    DataTypePtr element = std::make_shared<ElementDataType>();
+    return std::make_shared<DataTypeArray>(element);
+}
+
+// Runs `func_name(lhs, rhs)` over `rows` rows. The two inputs are appended to `block` at
+// positions 0 and 1, and the result at position 2.
+//
+// nullable_lhs / nullable_rhs wrap that argument's type in Nullable (and make the result type
+// Nullable). nullable_element declares both argument types as ARRAY<Nullable(FLOAT)> instead
+// of ARRAY<ElementType>.
+//
+// Returns: an InternalError if the function cannot be resolved; the first failing open()
+// status; otherwise a failing close() status, or else the execute() status.
+template <PrimitiveType ElementType>
+Status execute_cross_product(const std::string& func_name, ColumnPtr lhs, ColumnPtr rhs,
+                             size_t rows, Block* block, bool nullable_lhs = false,
+                             bool nullable_rhs = false, bool nullable_element = false) {
+    auto input_type = nullable_element ? array_nullable_float_type() : array_type<ElementType>();
+    auto lhs_type = nullable_lhs ? make_nullable(input_type) : input_type;
+    auto rhs_type = nullable_rhs ? make_nullable(input_type) : input_type;
+    auto result_type = nullable_lhs || nullable_rhs ? make_nullable(array_nullable_float_type())
+                                                    : array_nullable_float_type();
+    block->insert({std::move(lhs), lhs_type, "lhs"});
+    block->insert({std::move(rhs), rhs_type, "rhs"});
+
+    auto function = SimpleFunctionFactory::instance().get_function(
+            func_name, block->get_columns_with_type_and_name(), result_type);
+    EXPECT_NE(function, nullptr);
+    if (function == nullptr) {
+        // EXPECT_* is non-fatal (and ASSERT_* cannot be used in a Status-returning function),
+        // so bail out explicitly instead of dereferencing a null pointer below.
+        return Status::InternalError("function {} not found", func_name);
+    }
+
+    FunctionUtils fn_utils(result_type, {lhs_type, rhs_type}, false);
+    auto* fn_ctx = fn_utils.get_fn_ctx();
+    RETURN_IF_ERROR(function->open(fn_ctx, FunctionContext::FRAGMENT_LOCAL));
+    RETURN_IF_ERROR(function->open(fn_ctx, FunctionContext::THREAD_LOCAL));
+    block->insert({nullptr, result_type, "result"});
+    auto st = function->execute(fn_ctx, *block, {0, 1}, 2, rows);
+    RETURN_IF_ERROR(function->close(fn_ctx, FunctionContext::THREAD_LOCAL));
+    RETURN_IF_ERROR(function->close(fn_ctx, FunctionContext::FRAGMENT_LOCAL));
+    return st;
+}
+
+// Checks the result column (block position 2) row by row against `expected`.
+// Rows marked NULL in an optional top-level Nullable wrapper are skipped; element values are
+// compared with EXPECT_FLOAT_EQ after stripping any Nullable wrapper around the elements.
+void expect_array_rows(const Block& block, const FloatRows& expected) {
+    const auto full_column = block.get_by_position(2).column->convert_to_full_column_if_const();
+
+    const IColumn* column = full_column.get();
+    const ColumnUInt8::Container* row_nulls = nullptr;
+    if (column->is_nullable()) {
+        const auto& outer = assert_cast<const ColumnNullable&>(*column);
+        row_nulls = &outer.get_null_map_data();
+        column = outer.get_nested_column_ptr().get();
+    }
+    const auto& arrays = assert_cast<const ColumnArray&>(*column);
+
+    const IColumn* elements = &arrays.get_data();
+    if (elements->is_nullable()) {
+        elements = assert_cast<const ColumnNullable&>(*elements).get_nested_column_ptr().get();
+    }
+    const auto& floats = assert_cast<const ColumnFloat32&>(*elements).get_data();
+    const auto& ends = arrays.get_offsets();
+
+    ASSERT_EQ(expected.size(), ends.size());
+    size_t start = 0;
+    for (size_t r = 0; r < expected.size(); start = ends[r], ++r) {
+        const bool row_is_null = row_nulls != nullptr && (*row_nulls)[r];
+        if (row_is_null) {
+            continue;
+        }
+        ASSERT_EQ(expected[r].size(), ends[r] - start);
+        for (size_t k = 0; k < expected[r].size(); ++k) {
+            EXPECT_FLOAT_EQ(expected[r][k], floats[start + k]);
+        }
+    }
+}
+
+// Checks that the Nullable result column (block position 2) has exactly the given per-row
+// NULL pattern.
+void expect_top_nulls(const Block& block, const std::vector<bool>& expected_nulls) {
+    const auto& result = assert_cast<const ColumnNullable&>(*block.get_by_position(2).column);
+    ASSERT_EQ(expected_nulls.size(), result.size());
+    size_t idx = 0;
+    for (const bool want_null : expected_nulls) {
+        EXPECT_EQ(want_null, result.is_null_at(idx));
+        ++idx;
+    }
+}
+
+// Checks how many element values the Nullable(ARRAY<Nullable(FLOAT)>) result column
+// (block position 2) stores in total across all rows.
+void expect_nested_data_size(const Block& block, size_t expected_size) {
+    const IColumn& result = *block.get_by_position(2).column;
+    const auto& outer = assert_cast<const ColumnNullable&>(result);
+    const auto& arrays = assert_cast<const ColumnArray&>(*outer.get_nested_column_ptr());
+    const auto& elements = assert_cast<const ColumnNullable&>(arrays.get_data());
+    EXPECT_EQ(expected_size, elements.get_nested_column().size());
+}
+
+} // namespace
+
+TEST(function_array_cross_product_test, basic_and_alias) {
+    const FloatRows lhs = {{1.0F, 0.0F, 0.0F}, {-2.0F, 3.0F, 4.0F}};
+    const FloatRows rhs = {{0.0F, 1.0F, 0.0F}, {5.0F, -6.0F, 7.0F}};
+    const FloatRows expected = {{0.0F, 0.0F, 1.0F}, {45.0F, 34.0F, -3.0F}};
+
+    // The canonical name and its alias must produce the same results.
+    for (const char* fn_name : {"array_cross_product", "cross_product"}) {
+        Block block;
+        const Status st = execute_cross_product<TYPE_FLOAT>(
+                fn_name, create_array_column<TYPE_FLOAT>(lhs),
+                create_array_column<TYPE_FLOAT>(rhs), lhs.size(), &block);
+        ASSERT_TRUE(st.ok()) << st;
+        expect_array_rows(block, expected);
+    }
+}
+
+TEST(function_array_cross_product_test, partial_const_arguments) {
+    const FloatRows fixed = {{1.0F, 2.0F, 3.0F}};
+    const FloatRows varying = {{4.0F, 5.0F, 6.0F}, {-1.0F, 0.0F, 1.0F}};
+    const size_t n = varying.size();
+
+    // Constant operand on the left: [1, 2, 3] x row.
+    {
+        Block block;
+        const Status st = execute_cross_product<TYPE_FLOAT>(
+                "array_cross_product",
+                ColumnConst::create(create_array_column<TYPE_FLOAT>(fixed), n),
+                create_array_column<TYPE_FLOAT>(varying), n, &block);
+        ASSERT_TRUE(st.ok()) << st;
+        const FloatRows expected = {{-3.0F, 6.0F, -3.0F}, {2.0F, -4.0F, 2.0F}};
+        expect_array_rows(block, expected);
+    }
+
+    // Constant operand on the right: row x [1, 2, 3], the negation of the case above.
+    {
+        Block block;
+        const Status st = execute_cross_product<TYPE_FLOAT>(
+                "array_cross_product", create_array_column<TYPE_FLOAT>(varying),
+                ColumnConst::create(create_array_column<TYPE_FLOAT>(fixed), n), n, &block);
+        ASSERT_TRUE(st.ok()) << st;
+        const FloatRows expected = {{3.0F, -6.0F, 3.0F}, {-2.0F, 4.0F, -2.0F}};
+        expect_array_rows(block, expected);
+    }
+}
+
+TEST(function_array_cross_product_test, all_const_arguments) {
+    // Both operands are constant columns; every output row is [1, 2, 3] x [4, 5, 6].
+    constexpr size_t kRows = 2;
+    const FloatRows left_value = {{1.0F, 2.0F, 3.0F}};
+    const FloatRows right_value = {{4.0F, 5.0F, 6.0F}};
+    const FloatRows expected(kRows, std::vector<float> {-3.0F, 6.0F, -3.0F});
+
+    Block block;
+    const Status st = execute_cross_product<TYPE_FLOAT>(
+            "array_cross_product",
+            ColumnConst::create(create_array_column<TYPE_FLOAT>(left_value), kRows),
+            ColumnConst::create(create_array_column<TYPE_FLOAT>(right_value), kRows), kRows,
+            &block);
+    ASSERT_TRUE(st.ok()) << st;
+    expect_array_rows(block, expected);
+}
+
+TEST(function_array_cross_product_test, invalid_dimension) {
+    // A 2-element operand must be rejected with the dimension error.
+    const FloatRows two_elements = {{1.0F, 2.0F}};
+    const FloatRows three_elements = {{3.0F, 4.0F, 5.0F}};
+
+    Block block;
+    const Status st = execute_cross_product<TYPE_FLOAT>(
+            "array_cross_product", create_array_column<TYPE_FLOAT>(two_elements),
+            create_array_column<TYPE_FLOAT>(three_elements), 1, &block);
+    ASSERT_FALSE(st.ok());
+    const std::string message = st.to_string();
+    EXPECT_NE(message.find("exactly 3 elements"), std::string::npos) << message;
+}
+
+TEST(function_array_cross_product_test, null_element_returns_error) {
+    // Left operand: one non-NULL array [1, NULL, 3] (its middle element is NULL).
+    constexpr size_t kElements = 3;
+    const float element_values[kElements] = {1.0F, 2.0F, 3.0F};
+    const bool element_is_null[kElements] = {false, true, false};
+    auto elements = ColumnFloat32::create();
+    auto element_nulls = ColumnUInt8::create();
+    for (size_t i = 0; i < kElements; ++i) {
+        elements->get_data().push_back(element_values[i]);
+        element_nulls->get_data().push_back(
+                static_cast<ColumnUInt8::value_type>(element_is_null[i]));
+    }
+    auto row_ends = ColumnArray::ColumnOffsets::create();
+    row_ends->get_data().push_back(3);
+    auto left = ColumnArray::create(
+            ColumnNullable::create(std::move(elements), std::move(element_nulls)),
+            std::move(row_ends));
+
+    const FloatRows right_rows = {{4.0F, 5.0F, 6.0F}};
+    Block block;
+    const Status st = execute_cross_product<TYPE_FLOAT>(
+            "array_cross_product", std::move(left), create_array_column<TYPE_FLOAT>(right_rows),
+            1, &block, /*nullable_lhs=*/false, /*nullable_rhs=*/false,
+            /*nullable_element=*/true);
+    ASSERT_FALSE(st.ok());
+    const std::string message = st.to_string();
+    EXPECT_NE(message.find("cannot have null"), std::string::npos) << message;
+}
+
+TEST(function_array_cross_product_test, top_null_array_returns_null) {
+    // Left operand, Nullable(ARRAY<Nullable(FLOAT)>) with two rows:
+    //   row 0: [1, 0, 0]
+    //   row 1: NULL (its hidden payload [9, NULL, 9] must be ignored, not rejected).
+    constexpr size_t kElements = 6;
+    const float element_values[kElements] = {1.0F, 0.0F, 0.0F, 9.0F, 9.0F, 9.0F};
+    const bool element_is_null[kElements] = {false, false, false, false, true, false};
+    auto elements = ColumnFloat32::create();
+    auto element_nulls = ColumnUInt8::create();
+    for (size_t i = 0; i < kElements; ++i) {
+        elements->get_data().push_back(element_values[i]);
+        element_nulls->get_data().push_back(
+                static_cast<ColumnUInt8::value_type>(element_is_null[i]));
+    }
+    auto row_ends = ColumnArray::ColumnOffsets::create();
+    row_ends->get_data().push_back(3);
+    row_ends->get_data().push_back(6);
+    auto arrays = ColumnArray::create(
+            ColumnNullable::create(std::move(elements), std::move(element_nulls)),
+            std::move(row_ends));
+    auto row_nulls = ColumnUInt8::create();
+    row_nulls->get_data().push_back(0);
+    row_nulls->get_data().push_back(1);
+    auto left = ColumnNullable::create(std::move(arrays), std::move(row_nulls));
+
+    const FloatRows right_rows = {{0.0F, 1.0F, 0.0F}, {4.0F, 5.0F, 6.0F}};
+    Block block;
+    const Status st = execute_cross_product<TYPE_FLOAT>(
+            "array_cross_product", std::move(left), create_array_column<TYPE_FLOAT>(right_rows),
+            2, &block, /*nullable_lhs=*/true, /*nullable_rhs=*/false,
+            /*nullable_element=*/true);
+    ASSERT_TRUE(st.ok()) << st;
+    expect_top_nulls(block, {false, true});
+    // Only the non-NULL row contributes elements.
+    expect_nested_data_size(block, 3);
+}
+
+} // namespace doris
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java
b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java
index 4478bc9bcfe..af420fecd3b 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java
@@ -45,6 +45,7 @@ import
org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayConcat;
import
org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayContains;
import
org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayContainsAll;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayCount;
+import
org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayCrossProduct;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayCumSum;
import
org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayDifference;
import
org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayDistinct;
@@ -623,6 +624,7 @@ public class BuiltinScalarFunctions implements
FunctionHelper {
scalar(ArrayContains.class, "array_contains"),
scalar(ArrayContainsAll.class, "array_contains_all", "hasSubstr"),
scalar(ArrayCount.class, "array_count"),
+ scalar(ArrayCrossProduct.class, "array_cross_product",
"cross_product"),
scalar(ArrayCumSum.class, "array_cum_sum"),
scalar(ArrayDifference.class, "array_difference"),
scalar(ArrayDistinct.class, "array_distinct"),
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/executable/ArrayArithmetic.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/executable/ArrayArithmetic.java
index d1cd087fce6..1f3f9ffaf8d 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/executable/ArrayArithmetic.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/executable/ArrayArithmetic.java
@@ -24,6 +24,10 @@ import
org.apache.doris.nereids.trees.expressions.literal.ArrayLiteral;
import org.apache.doris.nereids.trees.expressions.literal.FloatLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
+import org.apache.doris.nereids.types.ArrayType;
+import org.apache.doris.nereids.types.FloatType;
+
+import com.google.common.collect.ImmutableList;
import java.util.List;
@@ -32,6 +36,59 @@ import java.util.List;
*/
public class ArrayArithmetic {
+    /**
+     * Constant-folds {@code array_cross_product(a, b)} for two array literals.
+     * The SQL name is forwarded so error messages mention the name the user wrote.
+     */
+    @ExecFunction(name = "array_cross_product")
+    public static Expression arrayCrossProduct(ArrayLiteral array1, ArrayLiteral array2) {
+        final String sqlName = "array_cross_product";
+        return crossProduct(sqlName, array1, array2);
+    }
+
+    /**
+     * Constant-folds the {@code cross_product} alias of {@code array_cross_product}.
+     * The alias name is forwarded so error messages mention the name the user wrote.
+     */
+    @ExecFunction(name = "cross_product")
+    public static Expression crossProduct(ArrayLiteral array1, ArrayLiteral array2) {
+        final String sqlName = "cross_product";
+        return crossProduct(sqlName, array1, array2);
+    }
+
+    /**
+     * Shared folding logic: validates that both literals hold exactly three non-null elements,
+     * then returns their cross product as an ARRAY&lt;FLOAT&gt; literal computed in float
+     * arithmetic.
+     *
+     * @throws AnalysisException on a wrong element count or a NULL element
+     */
+    private static Expression crossProduct(String functionName, ArrayLiteral array1,
+            ArrayLiteral array2) {
+        List<Literal> lhs = array1.getValue();
+        List<Literal> rhs = array2.getValue();
+        if (lhs.size() != 3 || rhs.size() != 3) {
+            throw new AnalysisException("function " + functionName
+                    + " requires both input arrays to have exactly 3 elements, got "
+                    + lhs.size() + " and " + rhs.size());
+        }
+        validateNoNull(functionName, lhs);
+        validateNoNull(functionName, rhs);
+
+        float[] x = new float[3];
+        for (int i = 0; i < 3; i++) {
+            x[i] = floatValue(lhs.get(i));
+        }
+        float[] y = new float[3];
+        for (int i = 0; i < 3; i++) {
+            y[i] = floatValue(rhs.get(i));
+        }
+
+        FloatLiteral cx = new FloatLiteral(x[1] * y[2] - x[2] * y[1]);
+        FloatLiteral cy = new FloatLiteral(x[2] * y[0] - x[0] * y[2]);
+        FloatLiteral cz = new FloatLiteral(x[0] * y[1] - x[1] * y[0]);
+        return new ArrayLiteral(ImmutableList.of(cx, cy, cz), ArrayType.of(FloatType.INSTANCE));
+    }
+
+ private static void validateNoNull(String functionName, List<Literal>
items) {
+ for (Literal item : items) {
+ if (item instanceof NullLiteral) {
+ throw new AnalysisException("function " + functionName + "
cannot have null");
+ }
+ }
+ }
+
+    /** Reads a numeric literal's value as a float; assumes the literal wraps a {@link Number}. */
+    private static float floatValue(Literal literal) {
+        Number number = (Number) literal.getValue();
+        return number.floatValue();
+    }
+
/**
* Compute cosine similarity between two float arrays.
* cosine_similarity(x, y) = dot(x, y) / (||x|| * ||y||)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayCrossProduct.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayCrossProduct.java
new file mode 100644
index 00000000000..770494d013e
--- /dev/null
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayCrossProduct.java
@@ -0,0 +1,75 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.trees.expressions.functions.scalar;
+
+import org.apache.doris.catalog.FunctionSignature;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import
org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
+import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
+import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
+import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
+import org.apache.doris.nereids.types.ArrayType;
+import org.apache.doris.nereids.types.FloatType;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+
+import java.util.List;
+
+/**
+ * array_cross_product function
+ */
+public class ArrayCrossProduct extends ScalarFunction implements
ExplicitlyCastableSignature,
+ BinaryExpression, PropagateNullable {
+
+ public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
+ FunctionSignature.ret(ArrayType.of(FloatType.INSTANCE))
+ .args(ArrayType.of(FloatType.INSTANCE),
ArrayType.of(FloatType.INSTANCE))
+ );
+
+ /**
+ * constructor with 2 arguments.
+ */
+ public ArrayCrossProduct(Expression arg0, Expression arg1) {
+ super("array_cross_product", arg0, arg1);
+ }
+
+ /** constructor for withChildren and reuse signature */
+ private ArrayCrossProduct(ScalarFunctionParams functionParams) {
+ super(functionParams);
+ }
+
+ /**
+ * withChildren.
+ */
+ @Override
+ public ArrayCrossProduct withChildren(List<Expression> children) {
+ Preconditions.checkArgument(children.size() == 2);
+ return new ArrayCrossProduct(getFunctionParams(children));
+ }
+
+ @Override
+ public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
+ return visitor.visitArrayCrossProduct(this, context);
+ }
+
+ @Override
+ public List<FunctionSignature> getSignatures() {
+ return SIGNATURES;
+ }
+}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java
index 1f82a48f1d7..2bc571e87a9 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java
@@ -47,6 +47,7 @@ import
org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayConcat;
import
org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayContains;
import
org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayContainsAll;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayCount;
+import
org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayCrossProduct;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayCumSum;
import
org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayDifference;
import
org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayDistinct;
@@ -677,6 +678,10 @@ public interface ScalarFunctionVisitor<R, C> {
return visitScalarFunction(arrayCount, context);
}
+    /** Visits {@link ArrayCrossProduct}; by default it is handled like any other scalar function. */
+    default R visitArrayCrossProduct(ArrayCrossProduct function, C context) {
+        return visitScalarFunction(function, context);
+    }
+
default R visitArrayCumSum(ArrayCumSum arrayCumSum, C context) {
return visitScalarFunction(arrayCumSum, context);
}
diff --git
a/regression-test/data/query_p0/sql_functions/array_functions/test_array_cross_product.out
b/regression-test/data/query_p0/sql_functions/array_functions/test_array_cross_product.out
new file mode 100644
index 00000000000..74b23712c70
--- /dev/null
+++
b/regression-test/data/query_p0/sql_functions/array_functions/test_array_cross_product.out
@@ -0,0 +1,35 @@
+-- This file is automatically generated. You should know what you did if you
want to edit this
+-- !array_cross_product_1 --
+[-5.875, -35, -21.25]
+
+-- !array_cross_product_2 --
+[-5.875, -511.9688, -271.9375]
+
+-- !array_cross_product_4 --
+\N
+
+-- !array_cross_product_5 --
+[635, 1277, 640]
+
+-- !array_cross_product_6 --
+[-311817, 3.93878e+07, 3.30314e+07]
+
+-- !array_cross_product_7 --
+[4.160573e+07, 6.614641e+08, 1.963683e+09]
+
+-- !array_cross_product_8 --
+[2.307395e+09, 1.011772e+12, 1.219063e+11]
+
+-- !array_cross_product_9 --
+[-3.000001e+18, -6e+18, -3e+18]
+
+-- !array_cross_product_10 --
+1 [-5.875, -35, -21.25] [-3.000001e+18, -6e+18, -3e+18]
+3 \N \N
+4 [124.5625, -0.875, -3501.75] [-3e+18, -6e+18, -3e+18]
+
+-- !array_cross_product_11 --
+1 [635, 1277, 640] [-311817, 3.93878e+07, 3.30314e+07]
[4.160573e+07, 6.614641e+08, 1.963683e+09] [2.307395e+09, 1.011772e+12,
1.219063e+11] [-3.000001e+18, -6e+18, -3e+18]
+3 \N \N \N \N \N
+4 [-3, -6, -3] [4.208e+08, 8.5959e+08, 4.5678e+08] [6.847411e+08,
4.247816e+09, 2.987408e+09] [7.036868e+13, 7.036871e+13, 1.759219e+13]
[-3e+18, -6e+18, -3e+18]
+
diff --git
a/regression-test/suites/query_p0/sql_functions/array_functions/test_array_cross_product.groovy
b/regression-test/suites/query_p0/sql_functions/array_functions/test_array_cross_product.groovy
new file mode 100644
index 00000000000..c24c5fa4ce3
--- /dev/null
+++
b/regression-test/suites/query_p0/sql_functions/array_functions/test_array_cross_product.groovy
@@ -0,0 +1,149 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("test_array_cross_product") { // regression checks for array_cross_product and its cross_product alias
+ qt_array_cross_product_1 """
+ SELECT array_cross_product([2.5, -3.0, 4.25], [-7.5, 0.5, 1.25])
+ """ // fractional literal inputs
+
+ qt_array_cross_product_2 """
+ SELECT cross_product([-0.125, 8.5, -16.0], [32.0, -0.5, 0.25])
+ """ // same function through the cross_product alias
+
+ qt_array_cross_product_4 """
+ SELECT array_cross_product(CAST(NULL AS ARRAY<FLOAT>), [4.0, 5.0, 6.0])
+ """ // a NULL array argument yields NULL
+
+ qt_array_cross_product_5 """
+ SELECT array_cross_product(CAST([-128, 0, 127] AS ARRAY<TINYINT>),
+ CAST([3, -5, 7] AS ARRAY<TINYINT>))
+ """ // TINYINT elements, including the type's min/max values
+
+ qt_array_cross_product_6 """
+ SELECT array_cross_product(CAST([-32768, 123, -456] AS
ARRAY<SMALLINT>),
+ CAST([789, -1011, 1213] AS ARRAY<SMALLINT>))
+ """ // SMALLINT elements
+
+ qt_array_cross_product_7 """
+ SELECT array_cross_product(CAST([123456, -7890, 42] AS ARRAY<INT>),
+ CAST([-314, 15926, -5358] AS ARRAY<INT>))
+ """ // INT elements
+
+ qt_array_cross_product_8 """
+ SELECT array_cross_product(CAST([-9000000, 12345, 67890] AS
ARRAY<BIGINT>),
+ CAST([24680, -13579, 112233] AS
ARRAY<BIGINT>))
+ """ // BIGINT elements
+
+ qt_array_cross_product_9 """
+ SELECT array_cross_product(CAST([1000000001, -2000000002, 3000000003]
AS ARRAY<LARGEINT>),
+ CAST([-4000000004, 5000000005, -6000000006]
AS ARRAY<LARGEINT>))
+ """ // LARGEINT elements; NOTE(review): expected output keeps only ~7 significant digits here - confirm that precision loss is intended
+
+ testFoldConst("SELECT array_cross_product([2.5, -3.0, 4.25], [-7.5, 0.5,
1.25])") // constant-folding check on fractional literals
+ testFoldConst("SELECT cross_product([-11, 0, 13], [17, -19, 23])") // constant-folding check via the alias on integer literals
+
+ sql "DROP TABLE IF EXISTS test_array_cross_product_table" // start from a clean table on re-runs
+ sql """
+ CREATE TABLE test_array_cross_product_table (
+ id INT,
+ lhs ARRAY<FLOAT>,
+ rhs ARRAY<FLOAT>,
+ tiny_lhs ARRAY<TINYINT>,
+ tiny_rhs ARRAY<TINYINT>,
+ small_lhs ARRAY<SMALLINT>,
+ small_rhs ARRAY<SMALLINT>,
+ int_lhs ARRAY<INT>,
+ int_rhs ARRAY<INT>,
+ big_lhs ARRAY<BIGINT>,
+ big_rhs ARRAY<BIGINT>,
+ large_lhs ARRAY<LARGEINT>,
+ large_rhs ARRAY<LARGEINT>
+ ) ENGINE=OLAP
+ DUPLICATE KEY(id)
+ DISTRIBUTED BY HASH(id) BUCKETS 1
+ PROPERTIES (
+ "replication_num" = "1"
+ )
+ """ // one lhs/rhs column pair per element type
+
+ sql """
+ INSERT INTO test_array_cross_product_table VALUES
+ (1, [2.5, -3.0, 4.25], [-7.5, 0.5, 1.25],
+ [-128, 0, 127], [3, -5, 7],
+ [-32768, 123, -456], [789, -1011, 1213],
+ [123456, -7890, 42], [-314, 15926, -5358],
+ [-9000000, 12345, 67890], [24680, -13579, 112233],
+ [1000000001, -2000000002, 3000000003], [-4000000004, 5000000005,
-6000000006]),
+ (2, [1.0, NULL, -3.0], [-4.0, 0.0, 6.5],
+ [1, NULL, -3], [-4, 0, 6],
+ [1, NULL, -3], [-4, 0, 6],
+ [1, NULL, -3], [-4, 0, 6],
+ [1, NULL, -3], [-4, 0, 6],
+ [1, NULL, -3], [-4, 0, 6]),
+ (3, NULL, [-4.0, 0.0, 6.5],
+ NULL, [-4, 0, 6],
+ NULL, [-4, 0, 6],
+ NULL, [-4, 0, 6],
+ NULL, [-4, 0, 6],
+ NULL, [-4, 0, 6]),
+ (4, [-0.0, 1000.5, -0.25], [3.5, -2.0, 0.125],
+ [7, -8, 9], [-10, 11, -12],
+ [30000, -20000, 10000], [-12345, 23456, -32768],
+ [-7654321, 1234567, -999], [333, -444, 555],
+ [2147483647, -2147483648, 4096], [-8192, 16384, -32768],
+ [-9000000000, 8000000000, -7000000000], [6000000000, -5000000000,
4000000000])
+ """ // row 2: NULL elements (error case); row 3: NULL arrays (NULL results)
+
+ qt_array_cross_product_10 """
+ SELECT id, array_cross_product(lhs, rhs), cross_product(large_lhs,
large_rhs)
+ FROM test_array_cross_product_table
+ WHERE id IN (1, 3, 4)
+ ORDER BY id
+ """ // row 2 excluded: its NULL elements raise an error, checked below
+
+ qt_array_cross_product_11 """
+ SELECT id,
+ array_cross_product(tiny_lhs, tiny_rhs),
+ array_cross_product(small_lhs, small_rhs),
+ array_cross_product(int_lhs, int_rhs),
+ array_cross_product(big_lhs, big_rhs),
+ array_cross_product(large_lhs, large_rhs)
+ FROM test_array_cross_product_table
+ WHERE id IN (1, 3, 4)
+ ORDER BY id
+ """ // every integer element type over the table rows
+
+ test { // a NULL element in a literal array is rejected
+ sql "SELECT array_cross_product([-11, NULL, 13], [17, -19, 23])"
+ exception "cannot have null"
+ }
+
+ test { // the alias rejects NULL elements the same way
+ sql "SELECT cross_product([-11, NULL, 13], [17, -19, 23])"
+ exception "cannot have null"
+ }
+
+ test { // NULL elements read from a table column are rejected too
+ sql "SELECT array_cross_product(lhs, rhs) FROM
test_array_cross_product_table WHERE id = 2"
+ exception "cannot have null"
+ }
+
+ test { // both arrays must have exactly 3 elements
+ sql "SELECT array_cross_product([1.0, -2.0], [3.0, -4.0, 5.0])"
+ exception "requires both input arrays to have exactly 3 elements"
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]