This is an automated email from the ASF dual-hosted git repository.

apitrou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 49c70cc  ARROW-4204: [Gandiva] add support for decimal subtract
49c70cc is described below

commit 49c70ccfa70f041e7830628ef6a57f1fbdf131a9
Author: Pindikura Ravindra <[email protected]>
AuthorDate: Thu Feb 14 17:27:49 2019 +0100

    ARROW-4204: [Gandiva] add support for decimal subtract
    
    Author: Pindikura Ravindra <[email protected]>
    
    Closes #3636 from pravindra/dsub and squashes the following commits:
    
    ee60c009 <Pindikura Ravindra> ARROW-4204:  add support for decimal subtract
---
 cpp/src/gandiva/decimal_ir.cc                   | 51 ++++++++++++++++++++-
 cpp/src/gandiva/decimal_ir.h                    |  3 ++
 cpp/src/gandiva/function_registry_arithmetic.cc |  1 +
 cpp/src/gandiva/precompiled/decimal_ops.cc      |  5 +++
 cpp/src/gandiva/precompiled/decimal_ops.h       |  5 +++
 cpp/src/gandiva/precompiled/decimal_ops_test.cc | 59 ++++++++++++++++++++++---
 cpp/src/gandiva/tests/decimal_single_test.cc    | 54 +++++++++++++++++-----
 7 files changed, 160 insertions(+), 18 deletions(-)

diff --git a/cpp/src/gandiva/decimal_ir.cc b/cpp/src/gandiva/decimal_ir.cc
index d10158a..f51f512 100644
--- a/cpp/src/gandiva/decimal_ir.cc
+++ b/cpp/src/gandiva/decimal_ir.cc
@@ -307,6 +307,52 @@ Status DecimalIR::BuildAdd() {
   return Status::OK();
 }
 
+Status DecimalIR::BuildSubtract() {
+  // Create fn prototype :
+  // int128_t
+  // subtract_decimal128_decimal128(int128_t x_value, int32_t x_precision, int32_t
+  // x_scale,
+  //                           int128_t y_value, int32_t y_precision, int32_t y_scale
+  //                           int32_t out_precision, int32_t out_scale)
+  auto i32 = types()->i32_type();
+  auto i128 = types()->i128_type();
+  auto function = BuildFunction("subtract_decimal128_decimal128", i128,
+                                {
+                                    {"x_value", i128},
+                                    {"x_precision", i32},
+                                    {"x_scale", i32},
+                                    {"y_value", i128},
+                                    {"y_precision", i32},
+                                    {"y_scale", i32},
+                                    {"out_precision", i32},
+                                    {"out_scale", i32},
+                                });
+
+  auto entry = llvm::BasicBlock::Create(*context(), "entry", function);
+  ir_builder()->SetInsertPoint(entry);
+
+  // reuse add function after negating y_value. i.e
+  //   add(x_value, x_precision, x_scale, -y_value, y_precision, y_scale,
+  //       out_precision, out_scale)
+  std::vector<llvm::Value*> args;
+  int i = 0;
+  for (auto& in_arg : function->args()) {
+    if (i == 3) {
+      auto y_neg_value = ir_builder()->CreateNeg(&in_arg);
+      args.push_back(y_neg_value);
+    } else {
+      args.push_back(&in_arg);
+    }
+    ++i;
+  }
+  auto value =
+      ir_builder()->CreateCall(module()->getFunction("add_decimal128_decimal128"), args);
+
+  // store result to out
+  ir_builder()->CreateRet(value);
+  return Status::OK();
+}
+
 Status DecimalIR::AddFunctions(Engine* engine) {
   auto decimal_ir = std::make_shared<DecimalIR>(engine);
 
@@ -317,7 +363,10 @@ Status DecimalIR::AddFunctions(Engine* engine) {
   decimal_ir->InitializeIntrinsics();
 
   // build "add"
-  return decimal_ir->BuildAdd();
+  ARROW_RETURN_NOT_OK(decimal_ir->BuildAdd());
+
+  // build "subtract"
+  return decimal_ir->BuildSubtract();
 }
 
 // Do an bitwise-or of all the overflow bits.
diff --git a/cpp/src/gandiva/decimal_ir.h b/cpp/src/gandiva/decimal_ir.h
index fae762c..fb9fe70 100644
--- a/cpp/src/gandiva/decimal_ir.h
+++ b/cpp/src/gandiva/decimal_ir.h
@@ -143,6 +143,9 @@ class DecimalIR : public FunctionIRBuilder {
   // Build the function for adding decimals.
   Status BuildAdd();
 
+  // Build the function for decimal subtraction.
+  Status BuildSubtract();
+
   // Add a trace in IR code.
   void AddTrace(const std::string& fmt, std::vector<llvm::Value*> args);
 
diff --git a/cpp/src/gandiva/function_registry_arithmetic.cc b/cpp/src/gandiva/function_registry_arithmetic.cc
index c5a798c..0a2ac93 100644
--- a/cpp/src/gandiva/function_registry_arithmetic.cc
+++ b/cpp/src/gandiva/function_registry_arithmetic.cc
@@ -58,6 +58,7 @@ std::vector<NativeFunction> GetArithmeticFunctionRegistry() {
       BINARY_GENERIC_SAFE_NULL_IF_NULL(mod, int64, int64, int64),
 
       BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(add, decimal128),
+      BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(subtract, decimal128),
 
       BINARY_RELATIONAL_BOOL_FN(equal),
       BINARY_RELATIONAL_BOOL_FN(not_equal),
diff --git a/cpp/src/gandiva/precompiled/decimal_ops.cc b/cpp/src/gandiva/precompiled/decimal_ops.cc
index 99231fe..887f42d 100644
--- a/cpp/src/gandiva/precompiled/decimal_ops.cc
+++ b/cpp/src/gandiva/precompiled/decimal_ops.cc
@@ -221,5 +221,10 @@ BasicDecimal128 Add(const BasicDecimalScalar128& x, const BasicDecimalScalar128&
   }
 }
 
+BasicDecimal128 Subtract(const BasicDecimalScalar128& x, const BasicDecimalScalar128& y,
+                         int32_t out_precision, int32_t out_scale) {
+  return Add(x, {-y.value(), y.precision(), y.scale()}, out_precision, out_scale);
+}
+
 }  // namespace decimalops
 }  // namespace gandiva
diff --git a/cpp/src/gandiva/precompiled/decimal_ops.h b/cpp/src/gandiva/precompiled/decimal_ops.h
index 1e202b8..5a6c94b 100644
--- a/cpp/src/gandiva/precompiled/decimal_ops.h
+++ b/cpp/src/gandiva/precompiled/decimal_ops.h
@@ -30,5 +30,10 @@ namespace decimalops {
 arrow::BasicDecimal128 Add(const BasicDecimalScalar128& x, const BasicDecimalScalar128& y,
                            int32_t out_precision, int32_t out_scale);
 
+/// Subtract 'y' from 'x', and return the result.
+arrow::BasicDecimal128 Subtract(const BasicDecimalScalar128& x,
+                                const BasicDecimalScalar128& y, int32_t out_precision,
+                                int32_t out_scale);
+
 }  // namespace decimalops
 }  // namespace gandiva
diff --git a/cpp/src/gandiva/precompiled/decimal_ops_test.cc b/cpp/src/gandiva/precompiled/decimal_ops_test.cc
index e16f202..ef2c402 100644
--- a/cpp/src/gandiva/precompiled/decimal_ops_test.cc
+++ b/cpp/src/gandiva/precompiled/decimal_ops_test.cc
@@ -29,8 +29,18 @@ namespace gandiva {
 
 class TestDecimalSql : public ::testing::Test {
  protected:
-  static void AddAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
-                           const DecimalScalar128& expected);
+  static void Verify(DecimalTypeUtil::Op op, const DecimalScalar128& x,
+                     const DecimalScalar128& y, const DecimalScalar128& expected);
+
+  void AddAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
+                    const DecimalScalar128& expected) {
+    return Verify(DecimalTypeUtil::kOpAdd, x, y, expected);
+  }
+
+  void SubtractAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
+                         const DecimalScalar128& expected) {
+    return Verify(DecimalTypeUtil::kOpSubtract, x, y, expected);
+  }
 };
 
 #define EXPECT_DECIMAL_EQ(x, y, expected, actual)                                    \
@@ -38,15 +48,28 @@ class TestDecimalSql : public ::testing::Test {
                               << " expected : " << expected.ToString() << " actual " \
                               << actual.ToString()
 
-void TestDecimalSql::AddAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
-                                  const DecimalScalar128& expected) {
+void TestDecimalSql::Verify(DecimalTypeUtil::Op op, const DecimalScalar128& x,
+                            const DecimalScalar128& y, const DecimalScalar128& expected) {
   auto t1 = std::make_shared<arrow::Decimal128Type>(x.precision(), x.scale());
   auto t2 = std::make_shared<arrow::Decimal128Type>(y.precision(), y.scale());
 
   Decimal128TypePtr out_type;
-  EXPECT_OK(DecimalTypeUtil::GetResultType(DecimalTypeUtil::kOpAdd, {t1, t2}, &out_type));
+  EXPECT_OK(DecimalTypeUtil::GetResultType(op, {t1, t2}, &out_type));
+
+  arrow::BasicDecimal128 out_value;
+  switch (op) {
+    case DecimalTypeUtil::kOpAdd:
+      out_value = decimalops::Add(x, y, out_type->precision(), out_type->scale());
+      break;
+
+    case DecimalTypeUtil::kOpSubtract:
+      out_value = decimalops::Subtract(x, y, out_type->precision(), out_type->scale());
+      break;
 
-  auto out_value = decimalops::Add(x, y, out_type->precision(), out_type->scale());
+    default:
+      // not implemented.
+      ASSERT_FALSE(true);
+  }
   EXPECT_DECIMAL_EQ(
       x, y, expected,
       DecimalScalar128(out_value, out_type->precision(), out_type->scale()));
@@ -74,4 +97,28 @@ TEST_F(TestDecimalSql, Add) {
               DecimalScalar128{"-99999999999999999999999999999990000010", 38, 6});
 }
 
+TEST_F(TestDecimalSql, Subtract) {
+  // fast-path
+  SubtractAndVerify(DecimalScalar128{"201", 30, 3},    // x
+                    DecimalScalar128{"301", 30, 3},    // y
+                    DecimalScalar128{"-100", 31, 3});  // expected
+
+  // max precision
+  SubtractAndVerify(
+      DecimalScalar128{"09999999999999999999999999999999000000", 38, 5},  // x
+      DecimalScalar128{"100", 38, 7},                                     // y
+      DecimalScalar128{"99999999999999999999999999999989999990", 38, 6});
+
+  // Both -ve
+  SubtractAndVerify(DecimalScalar128{"-201", 30, 3},   // x
+                    DecimalScalar128{"-301", 30, 2},   // y
+                    DecimalScalar128{"2809", 32, 3});  // expected
+
+  // -ve and max precision
+  SubtractAndVerify(
+      DecimalScalar128{"-09999999999999999999999999999999000000", 38, 5},  // x
+      DecimalScalar128{"-100", 38, 7},                                     // y
+      DecimalScalar128{"-99999999999999999999999999999989999990", 38, 6});
+}
+
 }  // namespace gandiva
diff --git a/cpp/src/gandiva/tests/decimal_single_test.cc b/cpp/src/gandiva/tests/decimal_single_test.cc
index 776ef6e..a83137f 100644
--- a/cpp/src/gandiva/tests/decimal_single_test.cc
+++ b/cpp/src/gandiva/tests/decimal_single_test.cc
@@ -31,9 +31,10 @@ using arrow::Decimal128;
 
 namespace gandiva {
 
-#define EXPECT_DECIMAL_SUM_EQUALS(x, y, expected, actual)                  \
-  EXPECT_EQ(expected, actual) << (x).ToString() << " + " << (y).ToString() \
-                              << " expected : " << (expected).ToString()   \
+#define EXPECT_DECIMAL_RESULT(op, x, y, expected, actual)                                \
+  EXPECT_EQ(expected, actual) << op << " (" << (x).ToString() << "),(" << (y).ToString() \
+                              << ")"                                                     \
+                              << " expected : " << (expected).ToString()                 \
                               << " actual : " << (actual).ToString();
 
 DecimalScalar128 decimal_literal(const char* value, int precision, int scale) {
@@ -46,8 +47,19 @@ class TestDecimalOps : public ::testing::Test {
   void SetUp() { pool_ = arrow::default_memory_pool(); }
 
   ArrayPtr MakeDecimalVector(const DecimalScalar128& in);
+
+  void Verify(DecimalTypeUtil::Op, const std::string& function, const DecimalScalar128& x,
+              const DecimalScalar128& y, const DecimalScalar128& expected);
+
   void AddAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
-                    const DecimalScalar128& expected);
+                    const DecimalScalar128& expected) {
+    Verify(DecimalTypeUtil::kOpAdd, "add", x, y, expected);
+  }
+
+  void SubtractAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
+                         const DecimalScalar128& expected) {
+    Verify(DecimalTypeUtil::kOpSubtract, "subtract", x, y, expected);
+  }
 
  protected:
   arrow::MemoryPool* pool_;
@@ -62,8 +74,9 @@ ArrayPtr TestDecimalOps::MakeDecimalVector(const DecimalScalar128& in) {
   return MakeArrowArrayDecimal(decimal_type, {decimal_value}, {true});
 }
 
-void TestDecimalOps::AddAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
-                                  const DecimalScalar128& expected) {
+void TestDecimalOps::Verify(DecimalTypeUtil::Op op, const std::string& function,
+                            const DecimalScalar128& x, const DecimalScalar128& y,
+                            const DecimalScalar128& expected) {
   auto x_type = std::make_shared<arrow::Decimal128Type>(x.precision(), x.scale());
   auto y_type = std::make_shared<arrow::Decimal128Type>(y.precision(), y.scale());
   auto field_x = field("x", x_type);
@@ -71,15 +84,14 @@ void TestDecimalOps::AddAndVerify(const DecimalScalar128& x, const DecimalScalar
   auto schema = arrow::schema({field_x, field_y});
 
   Decimal128TypePtr output_type;
-  auto status = DecimalTypeUtil::GetResultType(DecimalTypeUtil::kOpAdd, {x_type, y_type},
-                                               &output_type);
+  auto status = DecimalTypeUtil::GetResultType(op, {x_type, y_type}, &output_type);
   EXPECT_OK(status);
 
   // output fields
   auto res = field("res", output_type);
 
-  // build expression : x + y
-  auto expr = TreeExprBuilder::MakeExpression("add", {field_x, field_y}, res);
+  // build expression : x op y
+  auto expr = TreeExprBuilder::MakeExpression(function, {field_x, field_y}, res);
 
   // Build a projector for the expression.
   std::shared_ptr<Projector> projector;
@@ -106,7 +118,7 @@ void TestDecimalOps::AddAndVerify(const DecimalScalar128& x, const DecimalScalar
   std::string value_string = out_value.ToString(0);
   DecimalScalar128 actual{value_string, dtype->precision(), dtype->scale()};
 
-  EXPECT_DECIMAL_SUM_EQUALS(x, y, expected, actual);
+  EXPECT_DECIMAL_RESULT(function, x, y, expected, actual);
 }
 
 TEST_F(TestDecimalOps, TestAdd) {
@@ -221,4 +233,24 @@ TEST_F(TestDecimalOps, TestAdd) {
               decimal_literal("-10000992", 38, 7),  // y
               decimal_literal("-2001098", 38, 6));
 }
+
+// subtract is a wrapper over add. so, minimal tests are sufficient.
+TEST_F(TestDecimalOps, TestSubtract) {
+  // fast-path
+  SubtractAndVerify(decimal_literal("201", 30, 3),    // x
+                    decimal_literal("301", 30, 3),    // y
+                    decimal_literal("-100", 31, 3));  // expected
+
+  // max precision
+  SubtractAndVerify(
+      decimal_literal("09999999999999999999999999999999000000", 38, 5),  // x
+      decimal_literal("100", 38, 7),                                     // y
+      decimal_literal("99999999999999999999999999999989999990", 38, 6));
+
+  // Mix of +ve and -ve
+  SubtractAndVerify(decimal_literal("-201", 30, 3),    // x
+                    decimal_literal("301", 30, 2),     // y
+                    decimal_literal("-3211", 32, 3));  // expected
+}
+
 }  // namespace gandiva

Reply via email to