This is an automated email from the ASF dual-hosted git repository.

tanner pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/main by this push:
     new a651ea6347 [CALCITE-5640] Add SAFE_ADD function (enabled in BigQuery library)
a651ea6347 is described below

commit a651ea6347f1ef87a5bf4de14b0c7e778ba393b5
Author: Tanner Clary <[email protected]>
AuthorDate: Wed Aug 16 09:12:57 2023 -0700

    [CALCITE-5640] Add SAFE_ADD function (enabled in BigQuery library)
---
 babel/src/test/resources/sql/big-query.iq          | 68 +++++++++++++++++++
 .../calcite/adapter/enumerable/RexImpTable.java    | 10 +--
 .../org/apache/calcite/runtime/SqlFunctions.java   | 64 ++++++++++++++++--
 .../calcite/sql/fun/SqlLibraryOperators.java       |  9 +++
 .../org/apache/calcite/sql/type/ReturnTypes.java   |  9 +++
 site/_docs/reference.md                            |  1 +
 .../org/apache/calcite/test/SqlOperatorTest.java   | 78 ++++++++++++++++++++++
 7 files changed, 229 insertions(+), 10 deletions(-)

diff --git a/babel/src/test/resources/sql/big-query.iq 
b/babel/src/test/resources/sql/big-query.iq
index 82a37386e9..1e2552a699 100755
--- a/babel/src/test/resources/sql/big-query.iq
+++ b/babel/src/test/resources/sql/big-query.iq
@@ -600,6 +600,74 @@ FROM t;
 !ok
 !}
 
+#####################################################################
+# SAFE_ADD
+#
+# SAFE_ADD(value1, value2)
+#
+# Equivalent to the addition operator (+), but returns NULL if overflow/underflow occurs.
+SELECT SAFE_ADD(5, 4) as result;
++--------+
+| result |
++--------+
+|      9 |
++--------+
+(1 row)
+
+!ok
+
+# Overflow occurs if result is greater than 2^63 - 1
+SELECT SAFE_ADD(9223372036854775807, 2) as overflow_result;
++-----------------+
+| overflow_result |
++-----------------+
+|                 |
++-----------------+
+(1 row)
+
+!ok
+
+# Underflow occurs if result is less than -2^63
+SELECT SAFE_ADD(-9223372036854775806, -3) as underflow_result;
++------------------+
+| underflow_result |
++------------------+
+|                  |
++------------------+
+(1 row)
+
+!ok
+
+SELECT SAFE_ADD(CAST(1.7e308 as DOUBLE), CAST(1.7e308 as DOUBLE)) as 
double_overflow;
++-----------------+
+| double_overflow |
++-----------------+
+|                 |
++-----------------+
+(1 row)
+
+!ok
+
+SELECT SAFE_ADD(9, cast(9.999999999999999999e75 as DECIMAL(38, 19))) as 
decimal_overflow;
++------------------+
+| decimal_overflow |
++------------------+
+|                  |
++------------------+
+(1 row)
+
+!ok
+
+# NaN arguments should return NaN
+SELECT SAFE_ADD(CAST('NaN' AS DOUBLE), CAST(3 as BIGINT)) as NaN_result;
++------------+
+| NaN_result |
++------------+
+|        NaN |
++------------+
+(1 row)
+
+!ok
 #####################################################################
 # SAFE_MULTIPLY
 #
diff --git 
a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java 
b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java
index c23c2f8678..80c1eee8cf 100644
--- a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java
+++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java
@@ -223,6 +223,7 @@ import static 
org.apache.calcite.sql.fun.SqlLibraryOperators.REVERSE;
 import static org.apache.calcite.sql.fun.SqlLibraryOperators.RIGHT;
 import static org.apache.calcite.sql.fun.SqlLibraryOperators.RLIKE;
 import static org.apache.calcite.sql.fun.SqlLibraryOperators.RPAD;
+import static org.apache.calcite.sql.fun.SqlLibraryOperators.SAFE_ADD;
 import static org.apache.calcite.sql.fun.SqlLibraryOperators.SAFE_CAST;
 import static org.apache.calcite.sql.fun.SqlLibraryOperators.SAFE_MULTIPLY;
 import static org.apache.calcite.sql.fun.SqlLibraryOperators.SAFE_OFFSET;
@@ -624,7 +625,8 @@ public class RexImpTable {
       defineMethod(TRUNC, "struncate", NullPolicy.STRICT);
       defineMethod(TRUNCATE, "struncate", NullPolicy.STRICT);
 
-      map.put(SAFE_MULTIPLY, new SafeArithmeticImplementor());
+      map.put(SAFE_ADD, new SafeArithmeticImplementor("safeAdd"));
+      map.put(SAFE_MULTIPLY, new SafeArithmeticImplementor("safeMultiply"));
 
       map.put(PI, new PiImplementor());
       return populate2();
@@ -2395,15 +2397,15 @@ public class RexImpTable {
 
   /** Implementor for the {@code SAFE_MULTIPLY} function. */
   private static class SafeArithmeticImplementor extends MethodNameImplementor 
{
-    SafeArithmeticImplementor() {
-      super("safeMultiply", NullPolicy.STRICT, false);
+    SafeArithmeticImplementor(String methodName) {
+      super(methodName, NullPolicy.STRICT, false);
     }
 
     @Override Expression implementSafe(final RexToLixTranslator translator,
         final RexCall call, final List<Expression> argValueList) {
       Expression arg0 = convertType(argValueList.get(0), call.operands.get(0));
       Expression arg1 = convertType(argValueList.get(1), call.operands.get(1));
-      return Expressions.call(SqlFunctions.class, "safeMultiply", arg0, arg1);
+      return Expressions.call(SqlFunctions.class, methodName, arg0, arg1);
     }
 
     // Because BigQuery treats all int types as aliases for BIGINT (Java's 
long)
diff --git a/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java 
b/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
index 9374b31054..129b45fa7a 100644
--- a/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
+++ b/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
@@ -1735,6 +1735,61 @@ public class SqlFunctions {
     throw notArithmetic("*", b0, b1);
   }
 
+  /** SQL <code>SAFE_ADD</code> function applied to long values. */
+  public static @Nullable Long safeAdd(long b0, long b1) {
+    try {
+      return Math.addExact(b0, b1);
+    } catch (ArithmeticException e) {
+      return null;
+    }
+  }
+
+  /** SQL <code>SAFE_ADD</code> function applied to long and BigDecimal 
values. */
+  public static @Nullable BigDecimal safeAdd(long b0, BigDecimal b1) {
+    BigDecimal ans = BigDecimal.valueOf(b0).add(b1);
+    return safeDecimal(ans) ? ans : null;
+  }
+
+  /** SQL <code>SAFE_ADD</code> function applied to BigDecimal and long 
values. */
+  public static @Nullable BigDecimal safeAdd(BigDecimal b0, long b1) {
+    return safeAdd(b1, b0);
+  }
+
+  /** SQL <code>SAFE_ADD</code> function applied to BigDecimal values. */
+  public static @Nullable BigDecimal safeAdd(BigDecimal b0, BigDecimal b1) {
+    BigDecimal ans = b0.add(b1);
+    return safeDecimal(ans) ? ans : null;
+  }
+
+  /** SQL <code>SAFE_ADD</code> function applied to double and long values. */
+  public static @Nullable Double safeAdd(double b0, long b1) {
+    double ans = b0 + b1;
+    return safeDouble(ans) || !Double.isFinite(b0) ? ans : null;
+  }
+
+  /** SQL <code>SAFE_ADD</code> function applied to long and double values. */
+  public static @Nullable Double safeAdd(long b0, double b1) {
+    return safeAdd(b1, b0);
+  }
+
+  /** SQL <code>SAFE_ADD</code> function applied to double and BigDecimal 
values. */
+  public static @Nullable Double safeAdd(double b0, BigDecimal b1) {
+    double ans = b0 + b1.doubleValue();
+    return safeDouble(ans) || !Double.isFinite(b0) ? ans : null;
+  }
+
+  /** SQL <code>SAFE_ADD</code> function applied to BigDecimal and double 
values. */
+  public static @Nullable Double safeAdd(BigDecimal b0, double b1) {
+    return safeAdd(b1, b0);
+  }
+
+  /** SQL <code>SAFE_ADD</code> function applied to double values. */
+  public static @Nullable Double safeAdd(double b0, double b1) {
+    double ans = b0 + b1;
+    boolean isFinite = Double.isFinite(b0) && Double.isFinite(b1);
+    return safeDouble(ans) || !isFinite ? ans : null;
+  }
+
   /** SQL <code>SAFE_MULTIPLY</code> function applied to long values. */
   public static @Nullable Long safeMultiply(long b0, long b1) {
     try {
@@ -1752,8 +1807,7 @@ public class SqlFunctions {
 
   /** SQL <code>SAFE_MULTIPLY</code> function applied to BigDecimal and long 
values. */
   public static @Nullable BigDecimal safeMultiply(BigDecimal b0, long b1) {
-    BigDecimal ans = b0.multiply(BigDecimal.valueOf(b1));
-    return safeDecimal(ans) ? ans : null;
+    return safeMultiply(b1, b0);
   }
 
   /** SQL <code>SAFE_MULTIPLY</code> function applied to BigDecimal values. */
@@ -1770,8 +1824,7 @@ public class SqlFunctions {
 
   /** SQL <code>SAFE_MULTIPLY</code> function applied to long and double 
values. */
   public static @Nullable Double safeMultiply(long b0, double b1) {
-    double ans = b0 * b1;
-    return safeDouble(ans) || !Double.isFinite(b1) ? ans : null;
+    return safeMultiply(b1, b0);
   }
 
   /** SQL <code>SAFE_MULTIPLY</code> function applied to double and BigDecimal 
values. */
@@ -1782,8 +1835,7 @@ public class SqlFunctions {
 
   /** SQL <code>SAFE_MULTIPLY</code> function applied to BigDecimal and double 
values. */
   public static @Nullable Double safeMultiply(BigDecimal b0, double b1) {
-    double ans = b0.doubleValue() * b1;
-    return safeDouble(ans) || !Double.isFinite(b1) ? ans : null;
+    return safeMultiply(b1, b0);
   }
 
   /** SQL <code>SAFE_MULTIPLY</code> function applied to double values. */
diff --git 
a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java 
b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
index 6976d52cf8..0b782d3458 100644
--- a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
+++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
@@ -1682,6 +1682,15 @@ public abstract class SqlLibraryOperators {
           OperandTypes.family(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.TIMESTAMP,
               SqlTypeFamily.ANY));
 
+  /** The "SAFE_ADD(numeric1, numeric2)" function; equivalent to the {@code +} operator but
+   * returns null if overflow occurs. */
+  @LibraryOperator(libraries = {BIG_QUERY})
+  public static final SqlFunction SAFE_ADD =
+      SqlBasicFunction.create("SAFE_ADD",
+          ReturnTypes.SUM_FORCE_NULLABLE,
+          OperandTypes.NUMERIC_NUMERIC,
+          SqlFunctionCategory.NUMERIC);
+
   /** The "SAFE_MULTIPLY(numeric1, numeric2)" function; equivalent to the 
{@code *} operator but
    * returns null if overflow occurs. */
   @LibraryOperator(libraries = {BIG_QUERY})
diff --git a/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java 
b/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java
index 8bf20706f8..acdd678a9d 100644
--- a/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java
+++ b/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java
@@ -843,6 +843,15 @@ public abstract class ReturnTypes {
   public static final SqlReturnTypeInference DECIMAL_SUM_NULLABLE =
       DECIMAL_SUM.andThen(SqlTypeTransforms.TO_NULLABLE);
 
+  /**
+   * Same as {@link #DECIMAL_SUM_NULLABLE}, but also handles integer addition
+   * (falling back to {@link #LEAST_RESTRICTIVE}) and always returns a nullable
+   * type via {@link org.apache.calcite.sql.type.SqlTypeTransforms#FORCE_NULLABLE},
+   * because the operation may return null on overflow even for non-null operands.
+   */
+  public static final SqlReturnTypeInference SUM_FORCE_NULLABLE =
+      
DECIMAL_SUM_NULLABLE.orElse(LEAST_RESTRICTIVE).andThen(SqlTypeTransforms.FORCE_NULLABLE);
+
   /**
    * Type-inference strategy whereby the result type of a call is
    * {@link #DECIMAL_SUM_NULLABLE} with a fallback to {@link 
#LEAST_RESTRICTIVE}
diff --git a/site/_docs/reference.md b/site/_docs/reference.md
index 7daa15b7b2..791fce5b4f 100644
--- a/site/_docs/reference.md
+++ b/site/_docs/reference.md
@@ -2788,6 +2788,7 @@ BigQuery's type system uses confusingly different names 
for types and functions:
 | h s | string1 NOT RLIKE string2                    | Whether *string1* does 
not match regex pattern *string2* (similar to `NOT LIKE`, but uses Java regex)
 | b o | RPAD(string, length[, pattern ])             | Returns a string or 
bytes value that consists of *string* appended to *length* with *pattern*
 | b o | RTRIM(string)                                | Returns *string* with 
all blanks removed from the end
+| b | SAFE_ADD(numeric1, numeric2)                   | Returns *numeric1* + *numeric2*, or NULL on overflow
 | b | SAFE_CAST(value AS type)                       | Converts *value* to 
*type*, returning NULL if conversion fails
 | b | SAFE_MULTIPLY(numeric1, numeric2)              | Returns *numeric1* * 
*numeric2*, or NULL on overflow
 | b | SAFE_OFFSET(index)                             | Similar to `OFFSET` 
except null is returned if *index* is out of bounds
diff --git a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java 
b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
index 1edd551984..5bf8ae0aeb 100644
--- a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
+++ b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
@@ -7252,6 +7252,84 @@ public class SqlOperatorTest {
     f.checkNull("truncate(cast(null as double))");
   }
 
+  @Test void testSafeAddFunc() {
+    final SqlOperatorFixture f0 = 
fixture().setFor(SqlLibraryOperators.SAFE_ADD);
+    f0.checkFails("^safe_add(2, 3)^",
+        "No match found for function signature "
+        + "SAFE_ADD\\(<NUMERIC>, <NUMERIC>\\)", false);
+    final SqlOperatorFixture f = f0.withLibrary(SqlLibrary.BIG_QUERY);
+    // Basic test for each of the 9 ordered pairs of BIGINT, DECIMAL, and DOUBLE
+    f.checkScalar("safe_add(cast(20 as bigint), cast(20 as bigint))",
+        "40", "BIGINT");
+    f.checkScalar("safe_add(cast(20 as bigint), cast(1.2345 as decimal(5,4)))",
+        "21.2345", "DECIMAL(19, 4)");
+    f.checkScalar("safe_add(cast(1.2345 as decimal(5,4)), cast(20 as bigint))",
+        "21.2345", "DECIMAL(19, 4)");
+    f.checkScalar("safe_add(cast(1.2345 as decimal(5,4)), "
+        + "cast(2.0 as decimal(2, 1)))", "3.2345", "DECIMAL(6, 4)");
+    f.checkScalar("safe_add(cast(3 as double), cast(3 as bigint))",
+        "6.0", "DOUBLE");
+    f.checkScalar("safe_add(cast(3 as bigint), cast(3 as double))",
+        "6.0", "DOUBLE");
+    f.checkScalar("safe_add(cast(3 as double), cast(1.2345 as decimal(5, 4)))",
+        "4.2345", "DOUBLE");
+    f.checkScalar("safe_add(cast(1.2345 as decimal(5, 4)), cast(3 as double))",
+        "4.2345", "DOUBLE");
+    f.checkScalar("safe_add(cast(3 as double), cast(3 as double))",
+        "6.0", "DOUBLE");
+    // Tests for + and - Infinity
+    f.checkScalar("safe_add(cast('Infinity' as double), cast(3 as double))",
+        "Infinity", "DOUBLE");
+    f.checkScalar("safe_add(cast('-Infinity' as double), cast(3 as double))",
+        "-Infinity", "DOUBLE");
+    f.checkScalar("safe_add(cast('-Infinity' as double), "
+        + "cast('Infinity' as double))", "NaN", "DOUBLE");
+    // Tests for NaN
+    f.checkScalar("safe_add(cast('NaN' as double), cast(3 as bigint))",
+        "NaN", "DOUBLE");
+    f.checkScalar("safe_add(cast('NaN' as double), cast(1.23 as decimal(3, 
2)))",
+        "NaN", "DOUBLE");
+    f.checkScalar("safe_add(cast('NaN' as double), cast('Infinity' as 
double))",
+        "NaN", "DOUBLE");
+    f.checkScalar("safe_add(cast(3 as bigint), cast('NaN' as double))",
+        "NaN", "DOUBLE");
+    f.checkScalar("safe_add(cast(1.23 as decimal(3, 2)), cast('NaN' as 
double))",
+        "NaN", "DOUBLE");
+    // Overflow test for each pairing
+    f.checkNull("safe_add(cast(20 as bigint), "
+        + "cast(9223372036854775807 as bigint))");
+    f.checkNull("safe_add(cast(-20 as bigint), "
+        + "cast(-9223372036854775807 as bigint))");
+    f.checkNull("safe_add(9, cast(9.999999999999999999e75 as DECIMAL(38, 
19)))");
+    f.checkNull("safe_add(-9, cast(-9.999999999999999999e75 as DECIMAL(38, 
19)))");
+    f.checkNull("safe_add(cast(9.999999999999999999e75 as DECIMAL(38, 19)), 
9)");
+    f.checkNull("safe_add(cast(-9.999999999999999999e75 as DECIMAL(38, 19)), 
-9)");
+    f.checkNull("safe_add(cast(9.9e75 as DECIMAL(76, 0)), "
+        + "cast(9.9e75 as DECIMAL(76, 0)))");
+    f.checkNull("safe_add(cast(-9.9e75 as DECIMAL(76, 0)), "
+        + "cast(-9.9e75 as DECIMAL(76, 0)))");
+    f.checkNull("safe_add(cast(1.7976931348623157e308 as double), "
+        + "cast(9.9e7 as decimal(76, 0)))");
+    f.checkNull("safe_add(cast(-1.7976931348623157e308 as double), "
+        + "cast(-9.9e7 as decimal(76, 0)))");
+    f.checkNull("safe_add(cast(9.9e7 as decimal(76, 0)), "
+        + "cast(1.7976931348623157e308 as double))");
+    f.checkNull("safe_add(cast(-9.9e7 as decimal(76, 0)), "
+        + "cast(-1.7976931348623157e308 as double))");
+    f.checkNull("safe_add(cast(1.7976931348623157e308 as double), cast(3 as 
bigint))");
+    f.checkNull("safe_add(cast(-1.7976931348623157e308 as double), "
+        + "cast(-3 as bigint))");
+    f.checkNull("safe_add(cast(3 as bigint), cast(1.7976931348623157e308 as 
double))");
+    f.checkNull("safe_add(cast(-3 as bigint), "
+        + "cast(-1.7976931348623157e308 as double))");
+    f.checkNull("safe_add(cast(3 as double), cast(1.7976931348623157e308 as 
double))");
+    f.checkNull("safe_add(cast(-3 as double), "
+        + "cast(-1.7976931348623157e308 as double))");
+    // Check that a null argument returns null
+    f.checkNull("safe_add(cast(null as double), cast(3 as bigint))");
+    f.checkNull("safe_add(cast(3 as double), cast(null as bigint))");
+  }
+
   @Test void testSafeMultiplyFunc() {
     final SqlOperatorFixture f0 = 
fixture().setFor(SqlLibraryOperators.SAFE_MULTIPLY);
     f0.checkFails("^safe_multiply(2, 3)^",

Reply via email to