This is an automated email from the ASF dual-hosted git repository.
tanner pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git
The following commit(s) were added to refs/heads/main by this push:
new a651ea6347 [CALCITE-5640] Add SAFE_ADD function (enabled in BigQuery library)
a651ea6347 is described below
commit a651ea6347f1ef87a5bf4de14b0c7e778ba393b5
Author: Tanner Clary <[email protected]>
AuthorDate: Wed Aug 16 09:12:57 2023 -0700
[CALCITE-5640] Add SAFE_ADD function (enabled in BigQuery library)
---
babel/src/test/resources/sql/big-query.iq | 68 +++++++++++++++++++
.../calcite/adapter/enumerable/RexImpTable.java | 10 +--
.../org/apache/calcite/runtime/SqlFunctions.java | 64 ++++++++++++++++--
.../calcite/sql/fun/SqlLibraryOperators.java | 9 +++
.../org/apache/calcite/sql/type/ReturnTypes.java | 9 +++
site/_docs/reference.md | 1 +
.../org/apache/calcite/test/SqlOperatorTest.java | 78 ++++++++++++++++++++++
7 files changed, 229 insertions(+), 10 deletions(-)
diff --git a/babel/src/test/resources/sql/big-query.iq
b/babel/src/test/resources/sql/big-query.iq
index 82a37386e9..1e2552a699 100755
--- a/babel/src/test/resources/sql/big-query.iq
+++ b/babel/src/test/resources/sql/big-query.iq
@@ -600,6 +600,74 @@ FROM t;
!ok
!}
+#####################################################################
+# SAFE_ADD
+#
+# SAFE_ADD(value1, value2)
+#
+# Equivalent to the addition operator (+), but returns NULL if
overflow/underflow occurs.
+SELECT SAFE_ADD(5, 4) as result;
++--------+
+| result |
++--------+
+| 9 |
++--------+
+(1 row)
+
+!ok
+
+# BIGINT overflow occurs if the result is greater than 2^63 - 1
+SELECT SAFE_ADD(9223372036854775807, 2) as overflow_result;
++-----------------+
+| overflow_result |
++-----------------+
+| |
++-----------------+
+(1 row)
+
+!ok
+
+# BIGINT underflow (negative overflow) occurs if the result is less than -2^63
+SELECT SAFE_ADD(-9223372036854775806, -3) as underflow_result;
++------------------+
+| underflow_result |
++------------------+
+| |
++------------------+
+(1 row)
+
+!ok
+
+SELECT SAFE_ADD(CAST(1.7e308 as DOUBLE), CAST(1.7e308 as DOUBLE)) as
double_overflow;
++-----------------+
+| double_overflow |
++-----------------+
+| |
++-----------------+
+(1 row)
+
+!ok
+
+SELECT SAFE_ADD(9, cast(9.999999999999999999e75 as DECIMAL(38, 19))) as
decimal_overflow;
++------------------+
+| decimal_overflow |
++------------------+
+| |
++------------------+
+(1 row)
+
+!ok
+
+# A NaN argument propagates: the result is NaN, not NULL
+SELECT SAFE_ADD(CAST('NaN' AS DOUBLE), CAST(3 as BIGINT)) as NaN_result;
++------------+
+| NaN_result |
++------------+
+| NaN |
++------------+
+(1 row)
+
+!ok
#####################################################################
# SAFE_MULTIPLY
#
diff --git
a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java
b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java
index c23c2f8678..80c1eee8cf 100644
--- a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java
+++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java
@@ -223,6 +223,7 @@ import static
org.apache.calcite.sql.fun.SqlLibraryOperators.REVERSE;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.RIGHT;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.RLIKE;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.RPAD;
+import static org.apache.calcite.sql.fun.SqlLibraryOperators.SAFE_ADD;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.SAFE_CAST;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.SAFE_MULTIPLY;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.SAFE_OFFSET;
@@ -624,7 +625,8 @@ public class RexImpTable {
defineMethod(TRUNC, "struncate", NullPolicy.STRICT);
defineMethod(TRUNCATE, "struncate", NullPolicy.STRICT);
- map.put(SAFE_MULTIPLY, new SafeArithmeticImplementor());
+ map.put(SAFE_ADD, new SafeArithmeticImplementor("safeAdd"));
+ map.put(SAFE_MULTIPLY, new SafeArithmeticImplementor("safeMultiply"));
map.put(PI, new PiImplementor());
return populate2();
@@ -2395,15 +2397,15 @@ public class RexImpTable {
- /** Implementor for the {@code SAFE_MULTIPLY} function. */
+ /** Implementor for the {@code SAFE_ADD} and {@code SAFE_MULTIPLY} functions. */
private static class SafeArithmeticImplementor extends MethodNameImplementor
{
- SafeArithmeticImplementor() {
- super("safeMultiply", NullPolicy.STRICT, false);
+ SafeArithmeticImplementor(String methodName) {
+ super(methodName, NullPolicy.STRICT, false);
}
@Override Expression implementSafe(final RexToLixTranslator translator,
final RexCall call, final List<Expression> argValueList) {
Expression arg0 = convertType(argValueList.get(0), call.operands.get(0));
Expression arg1 = convertType(argValueList.get(1), call.operands.get(1));
- return Expressions.call(SqlFunctions.class, "safeMultiply", arg0, arg1);
+ return Expressions.call(SqlFunctions.class, methodName, arg0, arg1);
}
// Because BigQuery treats all int types as aliases for BIGINT (Java's
long)
diff --git a/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
b/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
index 9374b31054..129b45fa7a 100644
--- a/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
+++ b/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
@@ -1735,6 +1735,61 @@ public class SqlFunctions {
throw notArithmetic("*", b0, b1);
}
+ /** SQL <code>SAFE_ADD</code> function applied to long values; null on overflow. */
+ public static @Nullable Long safeAdd(long b0, long b1) {
+ try {
+ return Math.addExact(b0, b1);
+ } catch (ArithmeticException e) {
+ return null;
+ }
+ }
+
+ /** SQL <code>SAFE_ADD</code> function applied to long and BigDecimal
values. */
+ public static @Nullable BigDecimal safeAdd(long b0, BigDecimal b1) {
+ BigDecimal ans = BigDecimal.valueOf(b0).add(b1);
+ return safeDecimal(ans) ? ans : null;
+ }
+
+ /** SQL <code>SAFE_ADD</code> function applied to BigDecimal and long
values. */
+ public static @Nullable BigDecimal safeAdd(BigDecimal b0, long b1) {
+ return safeAdd(b1, b0);
+ }
+
+ /** SQL <code>SAFE_ADD</code> function applied to BigDecimal values; null on overflow. */
+ public static @Nullable BigDecimal safeAdd(BigDecimal b0, BigDecimal b1) {
+ BigDecimal ans = b0.add(b1);
+ return safeDecimal(ans) ? ans : null;
+ }
+
+ /** SQL <code>SAFE_ADD</code> function applied to double and long values. */
+ public static @Nullable Double safeAdd(double b0, long b1) {
+ double ans = b0 + b1;
+ return safeDouble(ans) || !Double.isFinite(b0) ? ans : null;
+ }
+
+ /** SQL <code>SAFE_ADD</code> function applied to long and double values. */
+ public static @Nullable Double safeAdd(long b0, double b1) {
+ return safeAdd(b1, b0);
+ }
+
+ /** SQL <code>SAFE_ADD</code> function applied to double and BigDecimal
values. */
+ public static @Nullable Double safeAdd(double b0, BigDecimal b1) {
+ double ans = b0 + b1.doubleValue();
+ return safeDouble(ans) || !Double.isFinite(b0) ? ans : null;
+ }
+
+ /** SQL <code>SAFE_ADD</code> function applied to BigDecimal and double
values. */
+ public static @Nullable Double safeAdd(BigDecimal b0, double b1) {
+ return safeAdd(b1, b0);
+ }
+
+ /** SQL <code>SAFE_ADD</code> function applied to double values; null on overflow. */
+ public static @Nullable Double safeAdd(double b0, double b1) {
+ double ans = b0 + b1;
+ boolean isFinite = Double.isFinite(b0) && Double.isFinite(b1);
+ return safeDouble(ans) || !isFinite ? ans : null;
+ }
+
/** SQL <code>SAFE_MULTIPLY</code> function applied to long values. */
public static @Nullable Long safeMultiply(long b0, long b1) {
try {
@@ -1752,8 +1807,7 @@ public class SqlFunctions {
/** SQL <code>SAFE_MULTIPLY</code> function applied to BigDecimal and long
values. */
public static @Nullable BigDecimal safeMultiply(BigDecimal b0, long b1) {
- BigDecimal ans = b0.multiply(BigDecimal.valueOf(b1));
- return safeDecimal(ans) ? ans : null;
+ return safeMultiply(b1, b0);
}
/** SQL <code>SAFE_MULTIPLY</code> function applied to BigDecimal values. */
@@ -1770,8 +1824,7 @@ public class SqlFunctions {
/** SQL <code>SAFE_MULTIPLY</code> function applied to long and double
values. */
public static @Nullable Double safeMultiply(long b0, double b1) {
- double ans = b0 * b1;
- return safeDouble(ans) || !Double.isFinite(b1) ? ans : null;
+ return safeMultiply(b1, b0);
}
/** SQL <code>SAFE_MULTIPLY</code> function applied to double and BigDecimal
values. */
@@ -1782,8 +1835,7 @@ public class SqlFunctions {
/** SQL <code>SAFE_MULTIPLY</code> function applied to BigDecimal and double
values. */
public static @Nullable Double safeMultiply(BigDecimal b0, double b1) {
- double ans = b0.doubleValue() * b1;
- return safeDouble(ans) || !Double.isFinite(b1) ? ans : null;
+ return safeMultiply(b1, b0);
}
/** SQL <code>SAFE_MULTIPLY</code> function applied to double values. */
diff --git
a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
index 6976d52cf8..0b782d3458 100644
--- a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
+++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
@@ -1682,6 +1682,15 @@ public abstract class SqlLibraryOperators {
OperandTypes.family(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.TIMESTAMP,
SqlTypeFamily.ANY));
+ /** The "SAFE_ADD(numeric1, numeric2)" function; equivalent to the {@code +}
operator but
+ * returns null if overflow occurs, so its return type is always nullable. */
+ @LibraryOperator(libraries = {BIG_QUERY})
+ public static final SqlFunction SAFE_ADD =
+ SqlBasicFunction.create("SAFE_ADD",
+ ReturnTypes.SUM_FORCE_NULLABLE,
+ OperandTypes.NUMERIC_NUMERIC,
+ SqlFunctionCategory.NUMERIC);
+
/** The "SAFE_MULTIPLY(numeric1, numeric2)" function; equivalent to the
{@code *} operator but
* returns null if overflow occurs. */
@LibraryOperator(libraries = {BIG_QUERY})
diff --git a/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java
b/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java
index 8bf20706f8..acdd678a9d 100644
--- a/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java
+++ b/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java
@@ -843,6 +843,15 @@ public abstract class ReturnTypes {
public static final SqlReturnTypeInference DECIMAL_SUM_NULLABLE =
DECIMAL_SUM.andThen(SqlTypeTransforms.TO_NULLABLE);
+ /**
+ * Same as {@link #DECIMAL_SUM_NULLABLE}, but falls back to {@link #LEAST_RESTRICTIVE}
+ * (so it also handles addition of integers, not just decimals) and always
+ * returns a nullable type, via {@link SqlTypeTransforms#FORCE_NULLABLE}, for
+ * functions such as {@code SAFE_ADD} that return null on overflow.
+ */
+ public static final SqlReturnTypeInference SUM_FORCE_NULLABLE =
+
DECIMAL_SUM_NULLABLE.orElse(LEAST_RESTRICTIVE).andThen(SqlTypeTransforms.FORCE_NULLABLE);
+
/**
* Type-inference strategy whereby the result type of a call is
* {@link #DECIMAL_SUM_NULLABLE} with a fallback to {@link
#LEAST_RESTRICTIVE}
diff --git a/site/_docs/reference.md b/site/_docs/reference.md
index 7daa15b7b2..791fce5b4f 100644
--- a/site/_docs/reference.md
+++ b/site/_docs/reference.md
@@ -2788,6 +2788,7 @@ BigQuery's type system uses confusingly different names
for types and functions:
| h s | string1 NOT RLIKE string2 | Whether *string1* does
not match regex pattern *string2* (similar to `NOT LIKE`, but uses Java regex)
| b o | RPAD(string, length[, pattern ]) | Returns a string or
bytes value that consists of *string* appended to *length* with *pattern*
| b o | RTRIM(string) | Returns *string* with
all blanks removed from the end
+| b | SAFE_ADD(numeric1, numeric2) | Returns *numeric1* +
*numeric2*, or NULL on overflow
| b | SAFE_CAST(value AS type) | Converts *value* to
*type*, returning NULL if conversion fails
| b | SAFE_MULTIPLY(numeric1, numeric2) | Returns *numeric1* *
*numeric2*, or NULL on overflow
| b | SAFE_OFFSET(index) | Similar to `OFFSET`
except null is returned if *index* is out of bounds
diff --git a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
index 1edd551984..5bf8ae0aeb 100644
--- a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
+++ b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
@@ -7252,6 +7252,84 @@ public class SqlOperatorTest {
f.checkNull("truncate(cast(null as double))");
}
+ @Test void testSafeAddFunc() {
+ final SqlOperatorFixture f0 =
fixture().setFor(SqlLibraryOperators.SAFE_ADD);
+ f0.checkFails("^safe_add(2, 3)^",
+ "No match found for function signature "
+ + "SAFE_ADD\\(<NUMERIC>, <NUMERIC>\\)", false);
+ final SqlOperatorFixture f = f0.withLibrary(SqlLibrary.BIG_QUERY);
+ // Basic test for each of the 9 ordered pairs of BIGINT, DECIMAL and DOUBLE
+ f.checkScalar("safe_add(cast(20 as bigint), cast(20 as bigint))",
+ "40", "BIGINT");
+ f.checkScalar("safe_add(cast(20 as bigint), cast(1.2345 as decimal(5,4)))",
+ "21.2345", "DECIMAL(19, 4)");
+ f.checkScalar("safe_add(cast(1.2345 as decimal(5,4)), cast(20 as bigint))",
+ "21.2345", "DECIMAL(19, 4)");
+ f.checkScalar("safe_add(cast(1.2345 as decimal(5,4)), "
+ + "cast(2.0 as decimal(2, 1)))", "3.2345", "DECIMAL(6, 4)");
+ f.checkScalar("safe_add(cast(3 as double), cast(3 as bigint))",
+ "6.0", "DOUBLE");
+ f.checkScalar("safe_add(cast(3 as bigint), cast(3 as double))",
+ "6.0", "DOUBLE");
+ f.checkScalar("safe_add(cast(3 as double), cast(1.2345 as decimal(5, 4)))",
+ "4.2345", "DOUBLE");
+ f.checkScalar("safe_add(cast(1.2345 as decimal(5, 4)), cast(3 as double))",
+ "4.2345", "DOUBLE");
+ f.checkScalar("safe_add(cast(3 as double), cast(3 as double))",
+ "6.0", "DOUBLE");
+ // Tests for + and - Infinity
+ f.checkScalar("safe_add(cast('Infinity' as double), cast(3 as double))",
+ "Infinity", "DOUBLE");
+ f.checkScalar("safe_add(cast('-Infinity' as double), cast(3 as double))",
+ "-Infinity", "DOUBLE");
+ f.checkScalar("safe_add(cast('-Infinity' as double), "
+ + "cast('Infinity' as double))", "NaN", "DOUBLE");
+ // Tests for NaN
+ f.checkScalar("safe_add(cast('NaN' as double), cast(3 as bigint))",
+ "NaN", "DOUBLE");
+ f.checkScalar("safe_add(cast('NaN' as double), cast(1.23 as decimal(3,
2)))",
+ "NaN", "DOUBLE");
+ f.checkScalar("safe_add(cast('NaN' as double), cast('Infinity' as
double))",
+ "NaN", "DOUBLE");
+ f.checkScalar("safe_add(cast(3 as bigint), cast('NaN' as double))",
+ "NaN", "DOUBLE");
+ f.checkScalar("safe_add(cast(1.23 as decimal(3, 2)), cast('NaN' as
double))",
+ "NaN", "DOUBLE");
+ // Overflow test for each pairing
+ f.checkNull("safe_add(cast(20 as bigint), "
+ + "cast(9223372036854775807 as bigint))");
+ f.checkNull("safe_add(cast(-20 as bigint), "
+ + "cast(-9223372036854775807 as bigint))");
+ f.checkNull("safe_add(9, cast(9.999999999999999999e75 as DECIMAL(38,
19)))");
+ f.checkNull("safe_add(-9, cast(-9.999999999999999999e75 as DECIMAL(38,
19)))");
+ f.checkNull("safe_add(cast(9.999999999999999999e75 as DECIMAL(38, 19)),
9)");
+ f.checkNull("safe_add(cast(-9.999999999999999999e75 as DECIMAL(38, 19)),
-9)");
+ f.checkNull("safe_add(cast(9.9e75 as DECIMAL(76, 0)), "
+ + "cast(9.9e75 as DECIMAL(76, 0)))");
+ f.checkNull("safe_add(cast(-9.9e75 as DECIMAL(76, 0)), "
+ + "cast(-9.9e75 as DECIMAL(76, 0)))");
+ f.checkNull("safe_add(cast(1.7976931348623157e308 as double), "
+ + "cast(9.9e7 as decimal(76, 0)))");
+ f.checkNull("safe_add(cast(-1.7976931348623157e308 as double), "
+ + "cast(-9.9e7 as decimal(76, 0)))");
+ f.checkNull("safe_add(cast(9.9e7 as decimal(76, 0)), "
+ + "cast(1.7976931348623157e308 as double))");
+ f.checkNull("safe_add(cast(-9.9e7 as decimal(76, 0)), "
+ + "cast(-1.7976931348623157e308 as double))");
+ f.checkNull("safe_add(cast(1.7976931348623157e308 as double), cast(3 as
bigint))");
+ f.checkNull("safe_add(cast(-1.7976931348623157e308 as double), "
+ + "cast(-3 as bigint))");
+ f.checkNull("safe_add(cast(3 as bigint), cast(1.7976931348623157e308 as
double))");
+ f.checkNull("safe_add(cast(-3 as bigint), "
+ + "cast(-1.7976931348623157e308 as double))");
+ f.checkNull("safe_add(cast(3 as double), cast(1.7976931348623157e308 as
double))");
+ f.checkNull("safe_add(cast(-3 as double), "
+ + "cast(-1.7976931348623157e308 as double))");
+ // Check that a null argument returns null
+ f.checkNull("safe_add(cast(null as double), cast(3 as bigint))");
+ f.checkNull("safe_add(cast(3 as double), cast(null as bigint))");
+ }
+
@Test void testSafeMultiplyFunc() {
final SqlOperatorFixture f0 =
fixture().setFor(SqlLibraryOperators.SAFE_MULTIPLY);
f0.checkFails("^safe_multiply(2, 3)^",