This is an automated email from the ASF dual-hosted git repository. wesm pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/arrow.git
commit 89163c59300e1c5ef722b358548ea6c844361a8b Author: Pindikura Ravindra <[email protected]> AuthorDate: Tue Sep 18 17:03:14 2018 +0530 [Gandiva] short-circuit regex startsW/endsW - for trivial patterns like startsWith/endsWith, short-circuit to avoid function-calls, pointer-indirections --- cpp/src/gandiva/expr_decomposer.cc | 14 ++++- cpp/src/gandiva/expr_decomposer.h | 3 + cpp/src/gandiva/function_registry.cc | 4 ++ cpp/src/gandiva/like_holder.cc | 47 +++++++++++++- cpp/src/gandiva/like_holder.h | 10 +++ cpp/src/gandiva/like_holder_test.cc | 44 ++++++++++++- cpp/src/gandiva/node.h | 86 ++++++++++++++------------ cpp/src/gandiva/precompiled/string_ops.cc | 46 +++++++++++--- cpp/src/gandiva/precompiled/string_ops_test.cc | 28 +++++++++ cpp/src/gandiva/precompiled/types.h | 9 +++ cpp/src/gandiva/tests/utf8_test.cc | 51 +++++++++++++++ cpp/src/gandiva/tree_expr_builder.cc | 6 +- 12 files changed, 294 insertions(+), 54 deletions(-) diff --git a/cpp/src/gandiva/expr_decomposer.cc b/cpp/src/gandiva/expr_decomposer.cc index 18f4021..a5eede6 100644 --- a/cpp/src/gandiva/expr_decomposer.cc +++ b/cpp/src/gandiva/expr_decomposer.cc @@ -46,9 +46,21 @@ Status ExprDecomposer::Visit(const FieldNode& node) { return Status::OK(); } +// Try and optimize a function node, by substituting with cheaper alternatives. +// eg. replacing 'like' with 'starts_with' can save function calls at evaluation +// time. +const FunctionNode ExprDecomposer::TryOptimize(const FunctionNode &node) { + if (node.descriptor()->name() == "like") { + return LikeHolder::TryOptimize(node); + } else { + return node; + } +} + // Decompose a field node - wherever possible, merge the validity vectors of the // child nodes. 
-Status ExprDecomposer::Visit(const FunctionNode& node) { +Status ExprDecomposer::Visit(const FunctionNode &in_node) { + auto node = TryOptimize(in_node); auto desc = node.descriptor(); FunctionSignature signature(desc->name(), desc->params(), desc->return_type()); const NativeFunction* native_function = registry_.LookupSignature(signature); diff --git a/cpp/src/gandiva/expr_decomposer.h b/cpp/src/gandiva/expr_decomposer.h index a6204f5..97b242a 100644 --- a/cpp/src/gandiva/expr_decomposer.h +++ b/cpp/src/gandiva/expr_decomposer.h @@ -59,6 +59,9 @@ class ExprDecomposer : public NodeVisitor { Status Visit(const LiteralNode& node) override; Status Visit(const BooleanNode& node) override; + // Optimize a function node, if possible. + const FunctionNode TryOptimize(const FunctionNode &node); + // stack of if nodes. class IfStackEntry { public: diff --git a/cpp/src/gandiva/function_registry.cc b/cpp/src/gandiva/function_registry.cc index 71de601..d265222 100644 --- a/cpp/src/gandiva/function_registry.cc +++ b/cpp/src/gandiva/function_registry.cc @@ -355,6 +355,10 @@ NativeFunction FunctionRegistry::pc_registry_[] = { VAR_LEN_TYPES(BINARY_RELATIONAL_SAFE_NULL_IF_NULL, greater_than), VAR_LEN_TYPES(BINARY_RELATIONAL_SAFE_NULL_IF_NULL, greater_than_or_equal_to), + BINARY_RELATIONAL_SAFE_NULL_IF_NULL(starts_with, utf8), + BINARY_RELATIONAL_SAFE_NULL_IF_NULL(ends_with, utf8), + BINARY_RELATIONAL_SAFE_NULL_IF_NULL(starts_with_plus_one, utf8), + BINARY_RELATIONAL_SAFE_NULL_IF_NULL(ends_with_plus_one, utf8), NativeFunction("like", DataTypeVector{utf8(), utf8()}, boolean(), true /*null_safe*/, RESULT_NULL_IF_NULL, "like_utf8_utf8", true /*needs_holder*/), diff --git a/cpp/src/gandiva/like_holder.cc b/cpp/src/gandiva/like_holder.cc index 273c5d2..6c35c3a 100644 --- a/cpp/src/gandiva/like_holder.cc +++ b/cpp/src/gandiva/like_holder.cc @@ -27,7 +27,52 @@ namespace gandiva { namespace helpers { #endif -Status LikeHolder::Make(const FunctionNode& node, std::shared_ptr<LikeHolder>* 
holder) { +RE2 LikeHolder::starts_with_regex_(R"((\w|\s)*\.\*)"); +RE2 LikeHolder::ends_with_regex_(R"(\.\*(\w|\s)*)"); +RE2 LikeHolder::starts_with_plus_one_regex_(R"((\w|\s)*\.)"); +RE2 LikeHolder::ends_with_plus_one_regex_(R"(\.(\w|\s)*)"); + +// Short-circuit pattern matches for the common sub cases: +// - starts_with, ends_with, starts_with_plus_one and ends_with_plus_one. +const FunctionNode LikeHolder::TryOptimize(const FunctionNode &node) { + std::shared_ptr<LikeHolder> holder; + auto status = Make(node, &holder); + if (status.ok()) { + std::string &pattern = holder->pattern_; + auto literal_type = node.children().at(1)->return_type(); + + if (RE2::FullMatch(pattern, starts_with_regex_)) { + auto prefix = pattern.substr(0, pattern.length() - 2); // trim .* + auto prefix_node = + std::make_shared<LiteralNode>(literal_type, LiteralHolder(prefix), false); + return FunctionNode("starts_with", {node.children().at(0), prefix_node}, + node.return_type()); + } else if (RE2::FullMatch(pattern, ends_with_regex_)) { + auto suffix = pattern.substr(2); // skip .* + auto suffix_node = + std::make_shared<LiteralNode>(literal_type, LiteralHolder(suffix), false); + return FunctionNode("ends_with", {node.children().at(0), suffix_node}, + node.return_type()); + } else if (RE2::FullMatch(pattern, starts_with_plus_one_regex_)) { + auto prefix = pattern.substr(0, pattern.length() - 1); // trim . + auto prefix_node = + std::make_shared<LiteralNode>(literal_type, LiteralHolder(prefix), false); + return FunctionNode("starts_with_plus_one", {node.children().at(0), prefix_node}, + node.return_type()); + } else if (RE2::FullMatch(pattern, ends_with_plus_one_regex_)) { + auto suffix = pattern.substr(1); // skip . + auto suffix_node = + std::make_shared<LiteralNode>(literal_type, LiteralHolder(suffix), false); + return FunctionNode("ends_with_plus_one", {node.children().at(0), suffix_node}, + node.return_type()); + } + } + + // didn't hit any of the optimisation paths. return original.
+ return node; +} + +Status LikeHolder::Make(const FunctionNode &node, std::shared_ptr<LikeHolder> *holder) { if (node.children().size() != 2) { return Status::Invalid("'like' function requires two parameters"); } diff --git a/cpp/src/gandiva/like_holder.h b/cpp/src/gandiva/like_holder.h index 673a4b1..be8c928 100644 --- a/cpp/src/gandiva/like_holder.h +++ b/cpp/src/gandiva/like_holder.h @@ -41,6 +41,9 @@ class LikeHolder : public FunctionHolder { static Status Make(const std::string& sql_pattern, std::shared_ptr<LikeHolder>* holder); + // Try and optimise a function node with a "like" pattern. + static const FunctionNode TryOptimize(const FunctionNode &node); + /// Return true if the data matches the pattern. bool operator()(const std::string& data) { return RE2::FullMatch(data, regex_); } @@ -49,6 +52,13 @@ class LikeHolder : public FunctionHolder { std::string pattern_; // posix pattern string, to help debugging RE2 regex_; // compiled regex for the pattern + + static RE2 starts_with_regex_; // pre-compiled pattern for matching starts_with + static RE2 ends_with_regex_; // pre-compiled pattern for matching ends_with + static RE2 starts_with_plus_one_regex_; // pre-compiled pattern for matching + // starts_with_plus_one + static RE2 + ends_with_plus_one_regex_; // pre-compiled pattern for matching ends_with_plus_one }; #ifdef GDV_HELPERS diff --git a/cpp/src/gandiva/like_holder_test.cc b/cpp/src/gandiva/like_holder_test.cc index 13d05b2..f3f5bae 100644 --- a/cpp/src/gandiva/like_holder_test.cc +++ b/cpp/src/gandiva/like_holder_test.cc @@ -25,7 +25,15 @@ namespace gandiva { -class TestLikeHolder : public ::testing::Test {}; +class TestLikeHolder : public ::testing::Test { + public: + FunctionNode BuildLike(std::string pattern) { + auto field = std::make_shared<FieldNode>(arrow::field("in", arrow::utf8())); + auto pattern_node = + std::make_shared<LiteralNode>(arrow::utf8(), LiteralHolder(pattern), false); + return FunctionNode("like", {field, pattern_node}, 
arrow::boolean()); + } +}; TEST_F(TestLikeHolder, TestMatchAny) { std::shared_ptr<LikeHolder> like_holder; @@ -76,7 +84,39 @@ TEST_F(TestLikeHolder, TestRegexEscape) { EXPECT_EQ(res, "%hello_abc.def#"); } -int main(int argc, char** argv) { +TEST_F(TestLikeHolder, TestOptimise) { + // optimise for 'starts_with' + auto fnode = LikeHolder::TryOptimize(BuildLike("xy 123z%")); + EXPECT_EQ(fnode.descriptor()->name(), "starts_with"); + EXPECT_EQ(fnode.ToString(), "bool starts_with(utf8, (string) xy 123z)"); + + // optimise for 'ends_with' + fnode = LikeHolder::TryOptimize(BuildLike("%xyz")); + EXPECT_EQ(fnode.descriptor()->name(), "ends_with"); + EXPECT_EQ(fnode.ToString(), "bool ends_with(utf8, (string) xyz)"); + + // optimise for 'starts_with_plus_one' + fnode = LikeHolder::TryOptimize(BuildLike("xyz_")); + EXPECT_EQ(fnode.ToString(), "bool starts_with_plus_one(utf8, (string) xyz)"); + + fnode = LikeHolder::TryOptimize(BuildLike("_xyz")); + EXPECT_EQ(fnode.ToString(), "bool ends_with_plus_one(utf8, (string) xyz)"); + + // no optimisation for others. + fnode = LikeHolder::TryOptimize(BuildLike("%xyz%")); + EXPECT_EQ(fnode.descriptor()->name(), "like"); + + fnode = LikeHolder::TryOptimize(BuildLike("_xyz_")); + EXPECT_EQ(fnode.descriptor()->name(), "like"); + + fnode = LikeHolder::TryOptimize(BuildLike("%xyz_")); + EXPECT_EQ(fnode.descriptor()->name(), "like"); + + fnode = LikeHolder::TryOptimize(BuildLike("x_yz%")); + EXPECT_EQ(fnode.descriptor()->name(), "like"); +} + +int main(int argc, char **argv) { ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/cpp/src/gandiva/node.h b/cpp/src/gandiva/node.h index c397362..fbb7495 100644 --- a/cpp/src/gandiva/node.h +++ b/cpp/src/gandiva/node.h @@ -38,12 +38,12 @@ class Node { virtual ~Node() = default; - const DataTypePtr& return_type() const { return return_type_; } + const DataTypePtr &return_type() const { return return_type_; } /// Derived classes should simply invoke the Visit api of the visitor.
- virtual Status Accept(NodeVisitor& visitor) const = 0; + virtual Status Accept(NodeVisitor &visitor) const = 0; - virtual std::string ToString() = 0; + virtual std::string ToString() const = 0; protected: DataTypePtr return_type_; @@ -52,23 +52,37 @@ class Node { /// \brief Node in the expression tree, representing a literal. class LiteralNode : public Node { public: - LiteralNode(DataTypePtr type, const LiteralHolder& holder, bool is_null) + LiteralNode(DataTypePtr type, const LiteralHolder &holder, bool is_null) : Node(type), holder_(holder), is_null_(is_null) {} - Status Accept(NodeVisitor& visitor) const override { return visitor.Visit(*this); } + Status Accept(NodeVisitor &visitor) const override { return visitor.Visit(*this); } - const LiteralHolder& holder() const { return holder_; } + const LiteralHolder &holder() const { return holder_; } bool is_null() const { return is_null_; } - std::string ToString() override { + std::string ToString() const override { std::stringstream ss; - ss << "(" << return_type()->ToString() << ") "; + ss << "(const " << return_type()->ToString() << ") "; if (is_null()) { ss << std::string("null"); return ss.str(); } + ss << holder(); + // The default formatter prints in decimal can cause a loss in precision. so, + // print in hex. Can't use hexfloat since gcc 4.9 doesn't support it. 
+ if (return_type()->id() == arrow::Type::DOUBLE) { + double dvalue = boost::get<double>(holder_); + uint64_t bits; + memcpy(&bits, &dvalue, sizeof(bits)); + ss << " raw(" << std::hex << bits << ")"; + } else if (return_type()->id() == arrow::Type::FLOAT) { + float fvalue = boost::get<float>(holder_); + uint32_t bits; + memcpy(&bits, &fvalue, sizeof(bits)); + ss << " raw(" << std::hex << bits << ")"; + } return ss.str(); } @@ -82,11 +96,13 @@ class FieldNode : public Node { public: explicit FieldNode(FieldPtr field) : Node(field->type()), field_(field) {} - Status Accept(NodeVisitor& visitor) const override { return visitor.Visit(*this); } + Status Accept(NodeVisitor &visitor) const override { return visitor.Visit(*this); } - const FieldPtr& field() const { return field_; } + const FieldPtr &field() const { return field_; } - std::string ToString() override { return field()->type()->name(); } + std::string ToString() const override { + return "(" + field()->type()->name() + ") " + field()->name(); + } private: FieldPtr field_; @@ -95,16 +111,14 @@ class FieldNode : public Node { /// \brief Node in the expression tree, representing a function. 
class FunctionNode : public Node { public: - FunctionNode(FuncDescriptorPtr descriptor, const NodeVector& children, - DataTypePtr retType) - : Node(retType), descriptor_(descriptor), children_(children) {} + FunctionNode(const std::string &name, const NodeVector &children, DataTypePtr retType); - Status Accept(NodeVisitor& visitor) const override { return visitor.Visit(*this); } + Status Accept(NodeVisitor &visitor) const override { return visitor.Visit(*this); } - const FuncDescriptorPtr& descriptor() const { return descriptor_; } - const NodeVector& children() const { return children_; } + const FuncDescriptorPtr &descriptor() const { return descriptor_; } + const NodeVector &children() const { return children_; } - std::string ToString() override { + std::string ToString() const override { std::stringstream ss; ss << descriptor()->return_type()->name() << " " << descriptor()->name() << "("; bool skip_comma = true; @@ -120,26 +134,20 @@ class FunctionNode : public Node { return ss.str(); } - /// Make a function node with params types specified by 'children', and - /// having return type ret_type. 
- static NodePtr MakeFunction(const std::string& name, const NodeVector& children, - DataTypePtr return_type); - private: FuncDescriptorPtr descriptor_; NodeVector children_; }; -inline NodePtr FunctionNode::MakeFunction(const std::string& name, - const NodeVector& children, - DataTypePtr return_type) { +inline FunctionNode::FunctionNode(const std::string &name, const NodeVector &children, + DataTypePtr return_type) + : Node(return_type), children_(children) { DataTypeVector param_types; - for (auto& child : children) { + for (auto &child : children) { param_types.push_back(child->return_type()); } - auto func_desc = FuncDescriptorPtr(new FuncDescriptor(name, param_types, return_type)); - return NodePtr(new FunctionNode(func_desc, children, return_type)); + descriptor_ = FuncDescriptorPtr(new FuncDescriptor(name, param_types, return_type)); } /// \brief Node in the expression tree, representing an if-else expression. @@ -151,13 +159,13 @@ class IfNode : public Node { then_node_(then_node), else_node_(else_node) {} - Status Accept(NodeVisitor& visitor) const override { return visitor.Visit(*this); } + Status Accept(NodeVisitor &visitor) const override { return visitor.Visit(*this); } - const NodePtr& condition() const { return condition_; } - const NodePtr& then_node() const { return then_node_; } - const NodePtr& else_node() const { return else_node_; } + const NodePtr &condition() const { return condition_; } + const NodePtr &then_node() const { return then_node_; } + const NodePtr &else_node() const { return else_node_; } - std::string ToString() override { + std::string ToString() const override { std::stringstream ss; ss << "if (" << condition()->ToString() << ") { "; ss << then_node()->ToString() << " } else { "; @@ -176,19 +184,19 @@ class BooleanNode : public Node { public: enum ExprType : char { AND, OR }; - BooleanNode(ExprType expr_type, const NodeVector& children) + BooleanNode(ExprType expr_type, const NodeVector &children) : Node(arrow::boolean()), 
expr_type_(expr_type), children_(children) {} - Status Accept(NodeVisitor& visitor) const override { return visitor.Visit(*this); } + Status Accept(NodeVisitor &visitor) const override { return visitor.Visit(*this); } ExprType expr_type() const { return expr_type_; } - const NodeVector& children() const { return children_; } + const NodeVector &children() const { return children_; } - std::string ToString() override { + std::string ToString() const override { std::stringstream ss; bool first = true; - for (auto& child : children_) { + for (auto &child : children_) { if (!first) { if (expr_type() == BooleanNode::AND) { ss << " && "; diff --git a/cpp/src/gandiva/precompiled/string_ops.cc b/cpp/src/gandiva/precompiled/string_ops.cc index 62b9e07..184c241 100644 --- a/cpp/src/gandiva/precompiled/string_ops.cc +++ b/cpp/src/gandiva/precompiled/string_ops.cc @@ -50,8 +50,8 @@ int32 mem_compare(const char* left, int32 left_len, const char* right, int32 rig } // Expand inner macro for all varlen types. -#define VAR_LEN_TYPES(INNER, NAME, OP) \ - INNER(NAME, utf8, OP) \ +#define VAR_LEN_OP_TYPES(INNER, NAME, OP) \ + INNER(NAME, utf8, OP) \ INNER(NAME, binary, OP) // Relational binary fns : left, right params are same, return is bool. 
@@ -62,11 +62,41 @@ int32 mem_compare(const char* left, int32 left_len, const char* right, int32 rig return mem_compare(left, left_len, right, right_len) OP 0; \ } -VAR_LEN_TYPES(BINARY_RELATIONAL, equal, ==) -VAR_LEN_TYPES(BINARY_RELATIONAL, not_equal, !=) -VAR_LEN_TYPES(BINARY_RELATIONAL, less_than, <) -VAR_LEN_TYPES(BINARY_RELATIONAL, less_than_or_equal_to, <=) -VAR_LEN_TYPES(BINARY_RELATIONAL, greater_than, >) -VAR_LEN_TYPES(BINARY_RELATIONAL, greater_than_or_equal_to, >=) +VAR_LEN_OP_TYPES(BINARY_RELATIONAL, equal, ==) +VAR_LEN_OP_TYPES(BINARY_RELATIONAL, not_equal, !=) +VAR_LEN_OP_TYPES(BINARY_RELATIONAL, less_than, <) +VAR_LEN_OP_TYPES(BINARY_RELATIONAL, less_than_or_equal_to, <=) +VAR_LEN_OP_TYPES(BINARY_RELATIONAL, greater_than, >) +VAR_LEN_OP_TYPES(BINARY_RELATIONAL, greater_than_or_equal_to, >=) + +// Expand inner macro for all varlen types. +#define VAR_LEN_TYPES(INNER) \ + INNER(utf8) \ + INNER(binary) + +FORCE_INLINE +bool starts_with_utf8_utf8(const char *data, int32 data_len, const char *prefix, + int32 prefix_len) { + return ((data_len >= prefix_len) && (memcmp(data, prefix, prefix_len) == 0)); +} + +FORCE_INLINE +bool ends_with_utf8_utf8(const char *data, int32 data_len, const char *suffix, + int32 suffix_len) { + return ((data_len >= suffix_len) && + (memcmp(data + data_len - suffix_len, suffix, suffix_len) == 0)); +} + +FORCE_INLINE +bool starts_with_plus_one_utf8_utf8(const char *data, int32 data_len, const char *prefix, + int32 prefix_len) { + return ((data_len == prefix_len + 1) && (memcmp(data, prefix, prefix_len) == 0)); +} + +FORCE_INLINE +bool ends_with_plus_one_utf8_utf8(const char *data, int32 data_len, const char *suffix, + int32 suffix_len) { + return ((data_len == suffix_len + 1) && (memcmp(data + 1, suffix, suffix_len) == 0)); +} } // extern "C" diff --git a/cpp/src/gandiva/precompiled/string_ops_test.cc b/cpp/src/gandiva/precompiled/string_ops_test.cc index f3e350e..b4f522c 100644 --- 
a/cpp/src/gandiva/precompiled/string_ops_test.cc +++ b/cpp/src/gandiva/precompiled/string_ops_test.cc @@ -37,4 +37,32 @@ TEST(TestStringOps, TestCompare) { EXPECT_GT(mem_compare(left, 7, right, 5), 0); } +TEST(TestStringOps, TestBeginsEnds) { + // starts_with + EXPECT_TRUE(starts_with_utf8_utf8("hello sir", 9, "hello", 5)); + EXPECT_TRUE(starts_with_utf8_utf8("hellos", 6, "hello", 5)); + EXPECT_TRUE(starts_with_utf8_utf8("hello", 5, "hello", 5)); + EXPECT_FALSE(starts_with_utf8_utf8("hell", 4, "hello", 5)); + EXPECT_FALSE(starts_with_utf8_utf8("world hello", 11, "hello", 5)); + + // ends_with + EXPECT_TRUE(ends_with_utf8_utf8("hello sir", 9, "sir", 3)); + EXPECT_TRUE(ends_with_utf8_utf8("ssir", 4, "sir", 3)); + EXPECT_TRUE(ends_with_utf8_utf8("sir", 3, "sir", 3)); + EXPECT_FALSE(ends_with_utf8_utf8("ir", 2, "sir", 3)); + EXPECT_FALSE(ends_with_utf8_utf8("hello", 5, "sir", 3)); + + // starts_with_plus_one + EXPECT_TRUE(starts_with_plus_one_utf8_utf8("hello ", 6, "hello", 5)); + EXPECT_FALSE(starts_with_plus_one_utf8_utf8("hello world", 11, "hello", 5)); + EXPECT_FALSE(starts_with_plus_one_utf8_utf8("hello", 5, "hello", 5)); + EXPECT_FALSE(starts_with_plus_one_utf8_utf8("hell", 4, "hello", 5)); + + // ends_with_plus_one + EXPECT_TRUE(ends_with_plus_one_utf8_utf8("gworld", 6, "world", 5)); + EXPECT_FALSE(ends_with_plus_one_utf8_utf8("hello world", 11, "world", 5)); + EXPECT_FALSE(ends_with_plus_one_utf8_utf8("world", 5, "world", 5)); + EXPECT_FALSE(ends_with_plus_one_utf8_utf8("worl", 4, "world", 5)); +} + } // namespace gandiva diff --git a/cpp/src/gandiva/precompiled/types.h b/cpp/src/gandiva/precompiled/types.h index 154b144..ab7e9c6 100644 --- a/cpp/src/gandiva/precompiled/types.h +++ b/cpp/src/gandiva/precompiled/types.h @@ -125,6 +125,15 @@ int32 mod_int64_int32(int64 left, int32 right); int64 divide_int64_int64(int64 in1, boolean is_valid1, int64 in2, boolean is_valid2, bool *out_valid); +bool starts_with_utf8_utf8(const char *data, int32 data_len, const char 
*prefix, + int32 prefix_len); +bool ends_with_utf8_utf8(const char *data, int32 data_len, const char *suffix, + int32 suffix_len); +bool starts_with_plus_one_utf8_utf8(const char *data, int32 data_len, const char *prefix, + int32 prefix_len); +bool ends_with_plus_one_utf8_utf8(const char *data, int32 data_len, const char *suffix, + int32 suffix_len); + } // extern "C" #endif // PRECOMPILED_TYPES_H diff --git a/cpp/src/gandiva/tests/utf8_test.cc b/cpp/src/gandiva/tests/utf8_test.cc index ed7715a..5e49a53 100644 --- a/cpp/src/gandiva/tests/utf8_test.cc +++ b/cpp/src/gandiva/tests/utf8_test.cc @@ -211,4 +211,55 @@ TEST_F(TestUtf8, TestLike) { EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); } +TEST_F(TestUtf8, TestBeginsEnds) { + // schema for input fields + auto field_a = field("a", utf8()); + auto schema = arrow::schema({field_a}); + + // output fields + auto res1 = field("res1", boolean()); + auto res2 = field("res2", boolean()); + + // build expressions. + // like(literal("spark%"), a) + // like(literal("%spark"), a) + + auto node_a = TreeExprBuilder::MakeField(field_a); + auto literal_begin = TreeExprBuilder::MakeStringLiteral("spark%"); + auto is_like1 = + TreeExprBuilder::MakeFunction("like", {node_a, literal_begin}, boolean()); + auto expr1 = TreeExprBuilder::MakeExpression(is_like1, res1); + + auto literal_end = TreeExprBuilder::MakeStringLiteral("%spark"); + auto is_like2 = TreeExprBuilder::MakeFunction("like", {node_a, literal_end}, boolean()); + auto expr2 = TreeExprBuilder::MakeExpression(is_like2, res2); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + Status status = Projector::Make(schema, {expr1, expr2}, &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 4; + auto array_a = + MakeArrowArrayUtf8({"park", "sparkle", "bright spark and fire", "fiery spark"}, + {true, true, true, true}); + + // expected output + auto exp1 = MakeArrowArrayBool({false, true, false, false}, {true, true, true, true}); + auto exp2 = MakeArrowArrayBool({false, false, false, true}, {true, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()) << status.message(); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp1, outputs.at(0)); + EXPECT_ARROW_ARRAY_EQUALS(exp2, outputs.at(1)); +} + } // namespace gandiva diff --git a/cpp/src/gandiva/tree_expr_builder.cc b/cpp/src/gandiva/tree_expr_builder.cc index b652483..3dbb8a0 100644 --- a/cpp/src/gandiva/tree_expr_builder.cc +++ b/cpp/src/gandiva/tree_expr_builder.cc @@ -103,7 +103,7 @@ NodePtr TreeExprBuilder::MakeFunction(const std::string& name, const NodeVector& if (result_type == nullptr) { return nullptr; } - return FunctionNode::MakeFunction(name, params, result_type); + return std::make_shared<FunctionNode>(name, params, result_type); } NodePtr TreeExprBuilder::MakeIf(NodePtr condition, NodePtr then_node, NodePtr else_node, @@ -147,7 +147,7 @@ ExpressionPtr TreeExprBuilder::MakeExpression(const std::string& function, auto node = MakeField(field); field_nodes.push_back(node); } - auto func_node = FunctionNode::MakeFunction(function, field_nodes, out_field->type()); + auto func_node = MakeFunction(function, field_nodes, out_field->type()); return MakeExpression(func_node, out_field); } @@ -170,7 +170,7 @@ ConditionPtr TreeExprBuilder::MakeCondition(const 
std::string& function, field_nodes.push_back(node); } - auto func_node = FunctionNode::MakeFunction(function, field_nodes, arrow::boolean()); + auto func_node = MakeFunction(function, field_nodes, arrow::boolean()); return ConditionPtr(new Condition(func_node)); }
