This is an automated email from the ASF dual-hosted git repository.

abhishek pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git


The following commit(s) were added to refs/heads/master by this push:
     new f3996b96ffd Fixes for safe_divide with vectorize and datatypes (#15839)
f3996b96ffd is described below

commit f3996b96ffdf13f49de9fc7273e0643a90dd7032
Author: Soumyava <[email protected]>
AuthorDate: Thu Feb 8 01:10:42 2024 -0800

    Fixes for safe_divide with vectorize and datatypes (#15839)
    
    * Fix for safe_divide with vectorize
    
    * More fixes
    
    * Return a null result (ExprEval.ofLong(null)) whenever the denominator is 0, including 0/0
---
 .../java/org/apache/druid/math/expr/Function.java  | 11 +++--
 .../org/apache/druid/math/expr/FunctionTest.java   |  9 ++--
 .../builtin/SafeDivideOperatorConversion.java      |  8 ++--
 .../apache/druid/sql/calcite/CalciteQueryTest.java | 25 +++++++++++
 .../druid/sql/calcite/CalciteSelectQueryTest.java  | 50 +++++++++++++++++++++-
 5 files changed, 91 insertions(+), 12 deletions(-)

diff --git a/processing/src/main/java/org/apache/druid/math/expr/Function.java 
b/processing/src/main/java/org/apache/druid/math/expr/Function.java
index 365a53d3362..2c8a26759f3 100644
--- a/processing/src/main/java/org/apache/druid/math/expr/Function.java
+++ b/processing/src/main/java/org/apache/druid/math/expr/Function.java
@@ -1173,14 +1173,17 @@ public interface Function extends NamedFunction
       );
     }
 
+    @Override
+    public boolean canVectorize(Expr.InputBindingInspector inspector, 
List<Expr> args)
+    {
+      return false;
+    }
+
     @Override
     protected ExprEval eval(final long x, final long y)
     {
       if (y == 0) {
-        if (x != 0) {
-          return ExprEval.ofLong(null);
-        }
-        return ExprEval.ofLong(0);
+        return ExprEval.ofLong(null);
       }
       return ExprEval.ofLong(x / y);
     }
diff --git 
a/processing/src/test/java/org/apache/druid/math/expr/FunctionTest.java 
b/processing/src/test/java/org/apache/druid/math/expr/FunctionTest.java
index 670dbe93e1f..0338efa4664 100644
--- a/processing/src/test/java/org/apache/druid/math/expr/FunctionTest.java
+++ b/processing/src/test/java/org/apache/druid/math/expr/FunctionTest.java
@@ -857,11 +857,14 @@ public class FunctionTest extends 
InitializedNullHandlingTest
     assertExpr("safe_divide(4.5, 2)", 2.25);
     assertExpr("safe_divide(3, 0)", null);
     assertExpr("safe_divide(1, 0.0)", null);
-    // NaN and Infinity cases
+    // NaN, Infinity and other weird cases
     assertExpr("safe_divide(NaN, 0.0)", null);
     assertExpr("safe_divide(0, NaN)", 0.0);
-    assertExpr("safe_divide(0, POSITIVE_INFINITY)", 
NullHandling.defaultLongValue());
-    assertExpr("safe_divide(POSITIVE_INFINITY,0)", 
NullHandling.defaultLongValue());
+    assertExpr("safe_divide(0, maxLong)", 0L);
+    assertExpr("safe_divide(maxLong,0)", null);
+    assertExpr("safe_divide(0.0, inf)", 0.0);
+    assertExpr("safe_divide(0.0, -inf)", -0.0);
+    assertExpr("safe_divide(0,0)", null);
   }
 
   @Test
diff --git 
a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/SafeDivideOperatorConversion.java
 
b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/SafeDivideOperatorConversion.java
index dd09feefadd..13c715316bb 100644
--- 
a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/SafeDivideOperatorConversion.java
+++ 
b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/SafeDivideOperatorConversion.java
@@ -22,8 +22,9 @@ package org.apache.druid.sql.calcite.expression.builtin;
 import org.apache.calcite.sql.SqlFunction;
 import org.apache.calcite.sql.SqlFunctionCategory;
 import org.apache.calcite.sql.SqlOperator;
-import org.apache.calcite.sql.type.OperandTypes;
 import org.apache.calcite.sql.type.ReturnTypes;
+import org.apache.calcite.sql.type.SqlTypeFamily;
+import org.apache.calcite.sql.type.SqlTypeTransforms;
 import org.apache.druid.java.util.common.StringUtils;
 import org.apache.druid.math.expr.Function;
 import org.apache.druid.sql.calcite.expression.DirectOperatorConversion;
@@ -33,9 +34,10 @@ public class SafeDivideOperatorConversion extends 
DirectOperatorConversion
 {
   private static final SqlFunction SQL_FUNCTION = OperatorConversions
       .operatorBuilder(StringUtils.toUpperCase(Function.SafeDivide.NAME))
-      .operandTypeChecker(OperandTypes.ANY_NUMERIC)
-      .returnTypeInference(ReturnTypes.QUOTIENT_NULLABLE)
+      .operandTypes(SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC)
+      
.returnTypeInference(ReturnTypes.LEAST_RESTRICTIVE.andThen(SqlTypeTransforms.FORCE_NULLABLE))
       .functionCategory(SqlFunctionCategory.USER_DEFINED_FUNCTION)
+      .requiredOperandCount(2)
       .build();
 
   public SafeDivideOperatorConversion()
diff --git 
a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java 
b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
index c4a57774ca0..2ce2fff0431 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
@@ -571,6 +571,31 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
     );
   }
 
+  @Test
+  public void testSafeDivide()
+  {
+    skipVectorize();
+    cannotVectorize();
+    final Map<String, Object> context = new HashMap<>(QUERY_CONTEXT_DEFAULT);
+
+    testQuery(
+        "select count(*) c from foo where ((floor(safe_divide(cast(cast(m1 as 
char) as bigint), 2))) = 0)",
+        context,
+        ImmutableList.of(
+            Druids.newTimeseriesQueryBuilder()
+                .dataSource(CalciteTests.DATASOURCE1)
+                .intervals(querySegmentSpec(Filtration.eternity()))
+                .virtualColumns(expressionVirtualColumn("v0", 
"floor(safe_divide(CAST(CAST(\"m1\", 'STRING'), 'LONG'),2))", ColumnType.LONG))
+                .filters(equality("v0", 0L, ColumnType.LONG))
+                .granularity(Granularities.ALL)
+                .aggregators(new CountAggregatorFactory("a0"))
+                .context(context)
+                .build()
+        ),
+        ImmutableList.of(new Object[]{1L})
+    );
+  }
+
   @Test
   public void testGroupByLimitWrappingOrderByAgg()
   {
diff --git 
a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSelectQueryTest.java 
b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSelectQueryTest.java
index 2cf8296dfc0..6569b52a90a 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSelectQueryTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSelectQueryTest.java
@@ -482,6 +482,52 @@ public class CalciteSelectQueryTest extends 
BaseCalciteQueryTest
     );
   }
 
+  @Test
+  public void testSafeDivideWithoutTable()
+  {
+    skipVectorize();
+    cannotVectorize();
+    final Map<String, Object> context = new HashMap<>(QUERY_CONTEXT_DEFAULT);
+
+    testQuery(
+        "select SAFE_DIVIDE(0, 0), SAFE_DIVIDE(1,0), SAFE_DIVIDE(10,2.5), "
+        + " SAFE_DIVIDE(10.5,3.5), SAFE_DIVIDE(10.5,3), SAFE_DIVIDE(10,2)",
+        context,
+        ImmutableList.of(
+            Druids.newScanQueryBuilder()
+                  .dataSource(
+                      InlineDataSource.fromIterable(
+                          ImmutableList.of(
+                              new Object[]{0L}
+                          ),
+                          RowSignature.builder().add("ZERO", 
ColumnType.LONG).build()
+                      )
+                  )
+                  .intervals(querySegmentSpec(Filtration.eternity()))
+                  .columns("v0", "v1", "v2", "v3", "v4")
+                  .virtualColumns(
+                      expressionVirtualColumn("v0", 
NullHandling.sqlCompatible() ? "null" : "0", ColumnType.LONG),
+                      expressionVirtualColumn("v1", "4.0", ColumnType.DOUBLE),
+                      expressionVirtualColumn("v2", "3.0", ColumnType.DOUBLE),
+                      expressionVirtualColumn("v3", "3.5", ColumnType.DOUBLE),
+                      expressionVirtualColumn("v4", "5", ColumnType.LONG)
+                  )
+                  
.resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
+                  .legacy(false)
+                  .context(context)
+                  .build()
+        ),
+        ImmutableList.of(new Object[]{
+            NullHandling.sqlCompatible() ? null : 0,
+            NullHandling.sqlCompatible() ? null : 0,
+            4.0D,
+            3.0D,
+            3.5D,
+            5
+        })
+    );
+  }
+
   @Test
   public void testSafeDivideExpressions()
   {
@@ -498,8 +544,8 @@ public class CalciteSelectQueryTest extends 
BaseCalciteQueryTest
     } else {
       expected = ImmutableList.of(
           new Object[]{null, null, null, 7.0F},
-          new Object[]{1.0F, 1L, 1.0, 3253230.0F},
-          new Object[]{0.0F, 0L, 0.0, 0.0F},
+          new Object[]{1.0F, 1L, 1.0D, 3253230.0F},
+          new Object[]{0.0F, null, 0.0D, 0.0F},
           new Object[]{null, null, null, null},
           new Object[]{null, null, null, null},
           new Object[]{null, null, null, null}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to