This is an automated email from the ASF dual-hosted git repository.

zhangzc pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 95f07e10e [CH] Support Logarithm function (#5184)
95f07e10e is described below

commit 95f07e10e51ca501bda4228edb0766bb02de0329
Author: exmy <xumov...@gmail.com>
AuthorDate: Fri Mar 29 10:20:38 2024 +0800

    [CH] Support Logarithm function (#5184)
    
    [CH] Support log function
---
 .../Parser/scalar_function_parser/ln.cpp           |  2 +-
 .../{logarithm.h => log.cpp}                       | 48 ++++++++++++----------
 .../Parser/scalar_function_parser/log10.cpp        |  2 +-
 .../Parser/scalar_function_parser/log1p.cpp        |  2 +-
 .../Parser/scalar_function_parser/log2.cpp         |  2 +-
 .../Parser/scalar_function_parser/logarithm.h      |  2 +-
 .../expression/ExpressionMappings.scala            |  3 +-
 .../utils/clickhouse/ClickHouseTestSettings.scala  |  4 --
 .../utils/clickhouse/ClickHouseTestSettings.scala  |  4 --
 .../glutenproject/expression/ExpressionNames.scala |  1 +
 10 files changed, 34 insertions(+), 36 deletions(-)

diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/ln.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/ln.cpp
index 452311483..081444ee5 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/ln.cpp
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/ln.cpp
@@ -28,7 +28,7 @@ public:
     static constexpr auto name = "log";
 
     String getName() const override { return name; }
-    String getCHFunctionName() const override { return "log"; }
+    String getCHFunctionName() const override { return name; }
     const DB::ActionsDAG::Node * getParameterLowerBound(ActionsDAGPtr & actions_dag, const DataTypePtr & data_type) const override
     {
         return addColumnToActionsDAG(actions_dag, data_type, 0.0);
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/logarithm.h b/cpp-ch/local-engine/Parser/scalar_function_parser/log.cpp
similarity index 53%
copy from cpp-ch/local-engine/Parser/scalar_function_parser/logarithm.h
copy to cpp-ch/local-engine/Parser/scalar_function_parser/log.cpp
index 6b44af187..31d7be6b4 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/logarithm.h
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/log.cpp
@@ -32,46 +32,50 @@ namespace ErrorCodes
 
 namespace local_engine
 {
-class FunctionParserLogBase : public FunctionParser
+class FunctionParserLog : public FunctionParser
 {
 public:
-    explicit FunctionParserLogBase(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
-    ~FunctionParserLogBase() override = default;
+    explicit FunctionParserLog(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
+    ~FunctionParserLog() override = default;
 
-    virtual String getCHFunctionName() const { return "log"; }
-    virtual const DB::ActionsDAG::Node * getParameterLowerBound(ActionsDAGPtr &, const DataTypePtr &) const { return nullptr; }
+    static constexpr auto name = "logarithm";
+
+    String getName() const override { return name; }
 
     const ActionsDAG::Node * parse(
         const substrait::Expression_ScalarFunction & substrait_func,
         ActionsDAGPtr & actions_dag) const override
     {
         /*
-            parse log(x) as
-            if (x <= c)
+            parse log(x, y) as
+            if (x <= 0.0 || y <= 0.0)
                 null
             else
-                log(x)
+                ln(y) / ln(x)
         */
         auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag);
-        if (parsed_args.size() != 1)
-            throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly one arguments", getName());
+        if (parsed_args.size() != 2)
+            throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly two arguments", getName());
+
+        const auto * x_node = parsed_args[0];
+        const auto * y_node = parsed_args[1];
+
+        const auto * ln_x_node = toFunctionNode(actions_dag, "ln", {x_node});
+        const auto * ln_y_node = toFunctionNode(actions_dag, "ln", {y_node});
+        auto result_type = std::make_shared<DataTypeFloat64>();
 
-        const auto * arg_node = parsed_args[0];
-        
-        const std::string ch_function_name = getCHFunctionName();
-        const auto * log_node = toFunctionNode(actions_dag, ch_function_name, {arg_node});
-        auto nullable_result_type = makeNullable(log_node->result_type);
+        const auto * null_const_node = addColumnToActionsDAG(actions_dag, makeNullable(result_type), Field());
+        const auto * zero_const_node = addColumnToActionsDAG(actions_dag, result_type, 0.0);
 
-        const auto * null_const_node = addColumnToActionsDAG(actions_dag, nullable_result_type, Field());
-        const auto * lower_bound_node = getParameterLowerBound(actions_dag, arg_node->result_type);
-        if (!lower_bound_node)
-            throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Vritual function {} may not implement for {}", "getParameterLowerBound", getName());
-        
-        const auto * le_node = toFunctionNode(actions_dag, "lessOrEquals", {arg_node, lower_bound_node});
-        const auto * result_node = toFunctionNode(actions_dag, "if", {le_node, null_const_node, log_node});
+        const auto * le_x_node = toFunctionNode(actions_dag, "lessOrEquals", {x_node, zero_const_node});
+        const auto * le_y_node = toFunctionNode(actions_dag, "lessOrEquals", {y_node, zero_const_node});
+        const auto * or_node = toFunctionNode(actions_dag, "or", {le_x_node, le_y_node});
+        const auto * divide_node = toFunctionNode(actions_dag, "divide", {ln_y_node, ln_x_node});
+        const auto * result_node = toFunctionNode(actions_dag, "if", {or_node, null_const_node, divide_node});
 
         return convertNodeTypeIfNeeded(substrait_func, result_node, actions_dag);
     }
 };
 
+static FunctionParserRegister<FunctionParserLog> register_log;
 }
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/log10.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/log10.cpp
index 191ca1187..b62ef486d 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/log10.cpp
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/log10.cpp
@@ -28,7 +28,7 @@ public:
     static constexpr auto name = "log10";
 
     String getName() const override { return name; }
-    String getCHFunctionName() const override { return "log10"; }
+    String getCHFunctionName() const override { return name; }
     const DB::ActionsDAG::Node * getParameterLowerBound(ActionsDAGPtr & actions_dag, const DataTypePtr & data_type) const override
     {
         return addColumnToActionsDAG(actions_dag, data_type, 0.0);
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/log1p.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/log1p.cpp
index d669c1eab..d7ad5aa8b 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/log1p.cpp
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/log1p.cpp
@@ -28,7 +28,7 @@ public:
     static constexpr auto name = "log1p";
 
     String getName() const override { return name; }
-    String getCHFunctionName() const override { return "log1p"; }
+    String getCHFunctionName() const override { return name; }
     const DB::ActionsDAG::Node * getParameterLowerBound(ActionsDAGPtr & actions_dag, const DataTypePtr & data_type) const override
     {
         return addColumnToActionsDAG(actions_dag, data_type, -1.0);
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/log2.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/log2.cpp
index 463795be9..5520fa035 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/log2.cpp
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/log2.cpp
@@ -28,7 +28,7 @@ public:
     static constexpr auto name = "log2";
 
     String getName() const override { return name; }
-    String getCHFunctionName() const override { return "log2"; }
+    String getCHFunctionName() const override { return name; }
     const DB::ActionsDAG::Node * getParameterLowerBound(ActionsDAGPtr & actions_dag, const DataTypePtr & data_type) const override
     {
         return addColumnToActionsDAG(actions_dag, data_type, 0.0);
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/logarithm.h b/cpp-ch/local-engine/Parser/scalar_function_parser/logarithm.h
index 6b44af187..59e22d8cb 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/logarithm.h
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/logarithm.h
@@ -38,7 +38,7 @@ public:
     explicit FunctionParserLogBase(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
     ~FunctionParserLogBase() override = default;
 
-    virtual String getCHFunctionName() const { return "log"; }
+    virtual String getCHFunctionName() const = 0;
     virtual const DB::ActionsDAG::Node * getParameterLowerBound(ActionsDAGPtr &, const DataTypePtr &) const { return nullptr; }
 
     const ActionsDAG::Node * parse(
diff --git a/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionMappings.scala b/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionMappings.scala
index 745d045fd..42619b407 100644
--- a/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionMappings.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionMappings.scala
@@ -132,9 +132,11 @@ object ExpressionMappings {
     Sig[Unhex](UNHEX),
     Sig[Hypot](HYPOT),
     Sig[Signum](SIGN),
+    Sig[Log10](LOG10),
     Sig[Log1p](LOG1P),
     Sig[Log2](LOG2),
     Sig[Log](LOG),
+    Sig[Logarithm](LOGARITHM),
     Sig[ToRadians](RADIANS),
     Sig[Greatest](GREATEST),
     Sig[Least](LEAST),
@@ -148,7 +150,6 @@ object ExpressionMappings {
     Sig[Atan2](ATAN2),
     Sig[Cos](COS),
     Sig[Cosh](COSH),
-    Sig[Log10](LOG10),
     Sig[ToDegrees](DEGREES),
     // SparkSQL DateTime functions
     Sig[Year](EXTRACT),
diff --git a/gluten-ut/spark32/src/test/scala/io/glutenproject/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark32/src/test/scala/io/glutenproject/utils/clickhouse/ClickHouseTestSettings.scala
index 7d05c8b58..c9aaaee28 100644
--- a/gluten-ut/spark32/src/test/scala/io/glutenproject/utils/clickhouse/ClickHouseTestSettings.scala
+++ b/gluten-ut/spark32/src/test/scala/io/glutenproject/utils/clickhouse/ClickHouseTestSettings.scala
@@ -835,12 +835,8 @@ class ClickHouseTestSettings extends BackendTestSettings {
     .exclude("factorial")
     .exclude("rint")
     .exclude("expm1")
-    .exclude("log")
-    .exclude("log10")
-    .exclude("log2")
     .exclude("unhex")
     .exclude("atan2")
-    .exclude("binary log")
     .exclude("round/bround")
     .exclude("SPARK-37388: width_bucket")
     .excludeGlutenTest("round/bround")
diff --git a/gluten-ut/spark33/src/test/scala/io/glutenproject/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark33/src/test/scala/io/glutenproject/utils/clickhouse/ClickHouseTestSettings.scala
index da029e575..d000dbecb 100644
--- a/gluten-ut/spark33/src/test/scala/io/glutenproject/utils/clickhouse/ClickHouseTestSettings.scala
+++ b/gluten-ut/spark33/src/test/scala/io/glutenproject/utils/clickhouse/ClickHouseTestSettings.scala
@@ -842,13 +842,9 @@ class ClickHouseTestSettings extends BackendTestSettings {
     .exclude("factorial")
     .exclude("rint")
     .exclude("expm1")
-    .exclude("log")
-    .exclude("log10")
     .exclude("bin")
-    .exclude("log2")
     .exclude("unhex")
     .exclude("atan2")
-    .exclude("binary log")
     .exclude("round/bround/floor/ceil")
     .excludeGlutenTest("round/bround/floor/ceil")
     .exclude("SPARK-36922: Support ANSI intervals for SIGN/SIGNUM")
diff --git a/shims/common/src/main/scala/io/glutenproject/expression/ExpressionNames.scala b/shims/common/src/main/scala/io/glutenproject/expression/ExpressionNames.scala
index 97fa914e9..56e10741a 100644
--- a/shims/common/src/main/scala/io/glutenproject/expression/ExpressionNames.scala
+++ b/shims/common/src/main/scala/io/glutenproject/expression/ExpressionNames.scala
@@ -156,6 +156,7 @@ object ExpressionNames {
   final val LOG1P = "log1p"
   final val LOG2 = "log2"
   final val LOG = "log"
+  final val LOGARITHM = "logarithm"
   final val RADIANS = "radians"
   final val GREATEST = "greatest"
   final val LEAST = "least"


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@gluten.apache.org
For additional commands, e-mail: commits-h...@gluten.apache.org

Reply via email to