This is an automated email from the ASF dual-hosted git repository.
apitrou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 49c70cc ARROW-4204: [Gandiva] add support for decimal subtract
49c70cc is described below
commit 49c70ccfa70f041e7830628ef6a57f1fbdf131a9
Author: Pindikura Ravindra <[email protected]>
AuthorDate: Thu Feb 14 17:27:49 2019 +0100
ARROW-4204: [Gandiva] add support for decimal subtract
Author: Pindikura Ravindra <[email protected]>
Closes #3636 from pravindra/dsub and squashes the following commits:
ee60c009 <Pindikura Ravindra> ARROW-4204: add support for decimal subtract
---
cpp/src/gandiva/decimal_ir.cc | 51 ++++++++++++++++++++-
cpp/src/gandiva/decimal_ir.h | 3 ++
cpp/src/gandiva/function_registry_arithmetic.cc | 1 +
cpp/src/gandiva/precompiled/decimal_ops.cc | 5 +++
cpp/src/gandiva/precompiled/decimal_ops.h | 5 +++
cpp/src/gandiva/precompiled/decimal_ops_test.cc | 59 ++++++++++++++++++++++---
cpp/src/gandiva/tests/decimal_single_test.cc | 54 +++++++++++++++++-----
7 files changed, 160 insertions(+), 18 deletions(-)
diff --git a/cpp/src/gandiva/decimal_ir.cc b/cpp/src/gandiva/decimal_ir.cc
index d10158a..f51f512 100644
--- a/cpp/src/gandiva/decimal_ir.cc
+++ b/cpp/src/gandiva/decimal_ir.cc
@@ -307,6 +307,52 @@ Status DecimalIR::BuildAdd() {
return Status::OK();
}
+Status DecimalIR::BuildSubtract() {
+ // Create fn prototype :
+ // int128_t
+ // subtract_decimal128_decimal128(int128_t x_value, int32_t x_precision, int32_t
+ // x_scale,
+ // int128_t y_value, int32_t y_precision, int32_t y_scale
+ // int32_t out_precision, int32_t out_scale)
+ auto i32 = types()->i32_type();
+ auto i128 = types()->i128_type();
+ auto function = BuildFunction("subtract_decimal128_decimal128", i128,
+ {
+ {"x_value", i128},
+ {"x_precision", i32},
+ {"x_scale", i32},
+ {"y_value", i128},
+ {"y_precision", i32},
+ {"y_scale", i32},
+ {"out_precision", i32},
+ {"out_scale", i32},
+ });
+
+ auto entry = llvm::BasicBlock::Create(*context(), "entry", function);
+ ir_builder()->SetInsertPoint(entry);
+
+ // reuse add function after negating y_value. i.e
+ // add(x_value, x_precision, x_scale, -y_value, y_precision, y_scale,
+ // out_precision, out_scale)
+ std::vector<llvm::Value*> args;
+ int i = 0;
+ for (auto& in_arg : function->args()) {
+ if (i == 3) {
+ auto y_neg_value = ir_builder()->CreateNeg(&in_arg);
+ args.push_back(y_neg_value);
+ } else {
+ args.push_back(&in_arg);
+ }
+ ++i;
+ }
+ auto value =
+ ir_builder()->CreateCall(module()->getFunction("add_decimal128_decimal128"), args);
+
+ // store result to out
+ ir_builder()->CreateRet(value);
+ return Status::OK();
+}
+
Status DecimalIR::AddFunctions(Engine* engine) {
auto decimal_ir = std::make_shared<DecimalIR>(engine);
@@ -317,7 +363,10 @@ Status DecimalIR::AddFunctions(Engine* engine) {
decimal_ir->InitializeIntrinsics();
// build "add"
- return decimal_ir->BuildAdd();
+ ARROW_RETURN_NOT_OK(decimal_ir->BuildAdd());
+
+ // build "subtract"
+ return decimal_ir->BuildSubtract();
}
// Do an bitwise-or of all the overflow bits.
diff --git a/cpp/src/gandiva/decimal_ir.h b/cpp/src/gandiva/decimal_ir.h
index fae762c..fb9fe70 100644
--- a/cpp/src/gandiva/decimal_ir.h
+++ b/cpp/src/gandiva/decimal_ir.h
@@ -143,6 +143,9 @@ class DecimalIR : public FunctionIRBuilder {
// Build the function for adding decimals.
Status BuildAdd();
+ // Build the function for decimal subtraction.
+ Status BuildSubtract();
+
// Add a trace in IR code.
void AddTrace(const std::string& fmt, std::vector<llvm::Value*> args);
diff --git a/cpp/src/gandiva/function_registry_arithmetic.cc b/cpp/src/gandiva/function_registry_arithmetic.cc
index c5a798c..0a2ac93 100644
--- a/cpp/src/gandiva/function_registry_arithmetic.cc
+++ b/cpp/src/gandiva/function_registry_arithmetic.cc
@@ -58,6 +58,7 @@ std::vector<NativeFunction> GetArithmeticFunctionRegistry() {
BINARY_GENERIC_SAFE_NULL_IF_NULL(mod, int64, int64, int64),
BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(add, decimal128),
+ BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(subtract, decimal128),
BINARY_RELATIONAL_BOOL_FN(equal),
BINARY_RELATIONAL_BOOL_FN(not_equal),
diff --git a/cpp/src/gandiva/precompiled/decimal_ops.cc b/cpp/src/gandiva/precompiled/decimal_ops.cc
index 99231fe..887f42d 100644
--- a/cpp/src/gandiva/precompiled/decimal_ops.cc
+++ b/cpp/src/gandiva/precompiled/decimal_ops.cc
@@ -221,5 +221,10 @@ BasicDecimal128 Add(const BasicDecimalScalar128& x, const BasicDecimalScalar128&
}
}
+BasicDecimal128 Subtract(const BasicDecimalScalar128& x, const BasicDecimalScalar128& y,
+ int32_t out_precision, int32_t out_scale) {
+ return Add(x, {-y.value(), y.precision(), y.scale()}, out_precision, out_scale);
+}
+
} // namespace decimalops
} // namespace gandiva
diff --git a/cpp/src/gandiva/precompiled/decimal_ops.h b/cpp/src/gandiva/precompiled/decimal_ops.h
index 1e202b8..5a6c94b 100644
--- a/cpp/src/gandiva/precompiled/decimal_ops.h
+++ b/cpp/src/gandiva/precompiled/decimal_ops.h
@@ -30,5 +30,10 @@ namespace decimalops {
arrow::BasicDecimal128 Add(const BasicDecimalScalar128& x, const BasicDecimalScalar128& y,
int32_t out_precision, int32_t out_scale);
+/// Subtract 'y' from 'x', and return the result.
+arrow::BasicDecimal128 Subtract(const BasicDecimalScalar128& x,
+ const BasicDecimalScalar128& y, int32_t out_precision,
+ int32_t out_scale);
+
} // namespace decimalops
} // namespace gandiva
diff --git a/cpp/src/gandiva/precompiled/decimal_ops_test.cc b/cpp/src/gandiva/precompiled/decimal_ops_test.cc
index e16f202..ef2c402 100644
--- a/cpp/src/gandiva/precompiled/decimal_ops_test.cc
+++ b/cpp/src/gandiva/precompiled/decimal_ops_test.cc
@@ -29,8 +29,18 @@ namespace gandiva {
class TestDecimalSql : public ::testing::Test {
protected:
- static void AddAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
- const DecimalScalar128& expected);
+ static void Verify(DecimalTypeUtil::Op op, const DecimalScalar128& x,
+ const DecimalScalar128& y, const DecimalScalar128& expected);
+
+ void AddAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
+ const DecimalScalar128& expected) {
+ return Verify(DecimalTypeUtil::kOpAdd, x, y, expected);
+ }
+
+ void SubtractAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
+ const DecimalScalar128& expected) {
+ return Verify(DecimalTypeUtil::kOpSubtract, x, y, expected);
+ }
};
#define EXPECT_DECIMAL_EQ(x, y, expected, actual) \
@@ -38,15 +48,28 @@ class TestDecimalSql : public ::testing::Test {
<< " expected : " << expected.ToString() << " actual " \
<< actual.ToString()
-void TestDecimalSql::AddAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
- const DecimalScalar128& expected) {
+void TestDecimalSql::Verify(DecimalTypeUtil::Op op, const DecimalScalar128& x,
+ const DecimalScalar128& y, const DecimalScalar128& expected) {
auto t1 = std::make_shared<arrow::Decimal128Type>(x.precision(), x.scale());
auto t2 = std::make_shared<arrow::Decimal128Type>(y.precision(), y.scale());
Decimal128TypePtr out_type;
- EXPECT_OK(DecimalTypeUtil::GetResultType(DecimalTypeUtil::kOpAdd, {t1, t2}, &out_type));
+ EXPECT_OK(DecimalTypeUtil::GetResultType(op, {t1, t2}, &out_type));
+
+ arrow::BasicDecimal128 out_value;
+ switch (op) {
+ case DecimalTypeUtil::kOpAdd:
+ out_value = decimalops::Add(x, y, out_type->precision(), out_type->scale());
+ break;
+
+ case DecimalTypeUtil::kOpSubtract:
+ out_value = decimalops::Subtract(x, y, out_type->precision(), out_type->scale());
+ break;
- auto out_value = decimalops::Add(x, y, out_type->precision(), out_type->scale());
+ default:
+ // not implemented.
+ ASSERT_FALSE(true);
+ }
EXPECT_DECIMAL_EQ(
x, y, expected,
DecimalScalar128(out_value, out_type->precision(), out_type->scale()));
@@ -74,4 +97,28 @@ TEST_F(TestDecimalSql, Add) {
DecimalScalar128{"-99999999999999999999999999999990000010", 38, 6});
}
+TEST_F(TestDecimalSql, Subtract) {
+ // fast-path
+ SubtractAndVerify(DecimalScalar128{"201", 30, 3}, // x
+ DecimalScalar128{"301", 30, 3}, // y
+ DecimalScalar128{"-100", 31, 3}); // expected
+
+ // max precision
+ SubtractAndVerify(
+ DecimalScalar128{"09999999999999999999999999999999000000", 38, 5}, // x
+ DecimalScalar128{"100", 38, 7}, // y
+ DecimalScalar128{"99999999999999999999999999999989999990", 38, 6});
+
+ // Both -ve
+ SubtractAndVerify(DecimalScalar128{"-201", 30, 3}, // x
+ DecimalScalar128{"-301", 30, 2}, // y
+ DecimalScalar128{"2809", 32, 3}); // expected
+
+ // -ve and max precision
+ SubtractAndVerify(
+ DecimalScalar128{"-09999999999999999999999999999999000000", 38, 5}, // x
+ DecimalScalar128{"-100", 38, 7}, // y
+ DecimalScalar128{"-99999999999999999999999999999989999990", 38, 6});
+}
+
} // namespace gandiva
diff --git a/cpp/src/gandiva/tests/decimal_single_test.cc b/cpp/src/gandiva/tests/decimal_single_test.cc
index 776ef6e..a83137f 100644
--- a/cpp/src/gandiva/tests/decimal_single_test.cc
+++ b/cpp/src/gandiva/tests/decimal_single_test.cc
@@ -31,9 +31,10 @@ using arrow::Decimal128;
namespace gandiva {
-#define EXPECT_DECIMAL_SUM_EQUALS(x, y, expected, actual) \
- EXPECT_EQ(expected, actual) << (x).ToString() << " + " << (y).ToString() \
- << " expected : " << (expected).ToString() \
+#define EXPECT_DECIMAL_RESULT(op, x, y, expected, actual) \
+ EXPECT_EQ(expected, actual) << op << " (" << (x).ToString() << "),(" << (y).ToString() \
+ << ")" \
+ << " expected : " << (expected).ToString() \
<< " actual : " << (actual).ToString();
DecimalScalar128 decimal_literal(const char* value, int precision, int scale) {
@@ -46,8 +47,19 @@ class TestDecimalOps : public ::testing::Test {
void SetUp() { pool_ = arrow::default_memory_pool(); }
ArrayPtr MakeDecimalVector(const DecimalScalar128& in);
+
+ void Verify(DecimalTypeUtil::Op, const std::string& function, const DecimalScalar128& x,
+ const DecimalScalar128& y, const DecimalScalar128& expected);
+
void AddAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
- const DecimalScalar128& expected);
+ const DecimalScalar128& expected) {
+ Verify(DecimalTypeUtil::kOpAdd, "add", x, y, expected);
+ }
+
+ void SubtractAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
+ const DecimalScalar128& expected) {
+ Verify(DecimalTypeUtil::kOpSubtract, "subtract", x, y, expected);
+ }
protected:
arrow::MemoryPool* pool_;
@@ -62,8 +74,9 @@ ArrayPtr TestDecimalOps::MakeDecimalVector(const DecimalScalar128& in) {
return MakeArrowArrayDecimal(decimal_type, {decimal_value}, {true});
}
-void TestDecimalOps::AddAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
- const DecimalScalar128& expected) {
+void TestDecimalOps::Verify(DecimalTypeUtil::Op op, const std::string& function,
+ const DecimalScalar128& x, const DecimalScalar128& y,
+ const DecimalScalar128& expected) {
auto x_type = std::make_shared<arrow::Decimal128Type>(x.precision(), x.scale());
auto y_type = std::make_shared<arrow::Decimal128Type>(y.precision(), y.scale());
auto field_x = field("x", x_type);
@@ -71,15 +84,14 @@ void TestDecimalOps::AddAndVerify(const DecimalScalar128& x, const DecimalScalar
auto schema = arrow::schema({field_x, field_y});
Decimal128TypePtr output_type;
- auto status = DecimalTypeUtil::GetResultType(DecimalTypeUtil::kOpAdd, {x_type, y_type},
- &output_type);
+ auto status = DecimalTypeUtil::GetResultType(op, {x_type, y_type}, &output_type);
EXPECT_OK(status);
// output fields
auto res = field("res", output_type);
- // build expression : x + y
- auto expr = TreeExprBuilder::MakeExpression("add", {field_x, field_y}, res);
+ // build expression : x op y
+ auto expr = TreeExprBuilder::MakeExpression(function, {field_x, field_y}, res);
// Build a projector for the expression.
std::shared_ptr<Projector> projector;
@@ -106,7 +118,7 @@ void TestDecimalOps::AddAndVerify(const DecimalScalar128& x, const DecimalScalar
std::string value_string = out_value.ToString(0);
DecimalScalar128 actual{value_string, dtype->precision(), dtype->scale()};
- EXPECT_DECIMAL_SUM_EQUALS(x, y, expected, actual);
+ EXPECT_DECIMAL_RESULT(function, x, y, expected, actual);
}
TEST_F(TestDecimalOps, TestAdd) {
@@ -221,4 +233,24 @@ TEST_F(TestDecimalOps, TestAdd) {
decimal_literal("-10000992", 38, 7), // y
decimal_literal("-2001098", 38, 6));
}
+
+// subtract is a wrapper over add. so, minimal tests are sufficient.
+TEST_F(TestDecimalOps, TestSubtract) {
+ // fast-path
+ SubtractAndVerify(decimal_literal("201", 30, 3), // x
+ decimal_literal("301", 30, 3), // y
+ decimal_literal("-100", 31, 3)); // expected
+
+ // max precision
+ SubtractAndVerify(
+ decimal_literal("09999999999999999999999999999999000000", 38, 5), // x
+ decimal_literal("100", 38, 7), // y
+ decimal_literal("99999999999999999999999999999989999990", 38, 6));
+
+ // Mix of +ve and -ve
+ SubtractAndVerify(decimal_literal("-201", 30, 3), // x
+ decimal_literal("301", 30, 2), // y
+ decimal_literal("-3211", 32, 3)); // expected
+}
+
} // namespace gandiva