This is an automated email from the ASF dual-hosted git repository.

zabetak pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/hive.git


The following commit(s) were added to refs/heads/master by this push:
     new 301bfb57f67 HIVE-29424: CBO plans should use histogram statistics for 
range predicates with a CAST (#6293)
301bfb57f67 is described below

commit 301bfb57f67e10e92fb84569cbf62066d056886a
Author: Thomas Rebele <[email protected]>
AuthorDate: Fri Mar 13 08:56:46 2026 +0100

    HIVE-29424: CBO plans should use histogram statistics for range predicates 
with a CAST (#6293)
---
 .../calcite/stats/FilterSelectivityEstimator.java  | 478 +++++++++++++++++---
 .../stats/TestFilterSelectivityEstimator.java      | 500 ++++++++++++++++++++-
 2 files changed, 892 insertions(+), 86 deletions(-)

diff --git 
a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/stats/FilterSelectivityEstimator.java
 
b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/stats/FilterSelectivityEstimator.java
index b18c525c884..e0f8eb41bf3 100644
--- 
a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/stats/FilterSelectivityEstimator.java
+++ 
b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/stats/FilterSelectivityEstimator.java
@@ -22,14 +22,19 @@
 import java.util.Collections;
 import java.util.GregorianCalendar;
 import java.util.List;
+import java.util.Objects;
+import java.util.Optional;
 import java.util.Set;
 
+import com.google.common.collect.BoundType;
+import com.google.common.collect.Range;
 import org.apache.calcite.plan.RelOptUtil;
 import org.apache.calcite.plan.RelOptUtil.InputReferencedVisitor;
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.core.Filter;
 import org.apache.calcite.rel.core.Project;
 import org.apache.calcite.rel.metadata.RelMetadataQuery;
+import org.apache.calcite.rel.type.RelDataType;
 import org.apache.calcite.rex.RexBuilder;
 import org.apache.calcite.rex.RexCall;
 import org.apache.calcite.rex.RexInputRef;
@@ -40,6 +45,7 @@
 import org.apache.calcite.rex.RexVisitorImpl;
 import org.apache.calcite.sql.SqlKind;
 import org.apache.calcite.sql.type.SqlTypeName;
+import org.apache.calcite.sql.type.SqlTypeUtil;
 import org.apache.calcite.util.ImmutableBitSet;
 import org.apache.datasketches.kll.KllFloatsSketch;
 import org.apache.datasketches.memory.Memory;
@@ -184,91 +190,384 @@ public Double visitCall(RexCall call) {
     return selectivity;
   }
 
-  private double computeRangePredicateSelectivity(RexCall call, SqlKind op) {
-    final boolean isLiteralLeft = 
call.getOperands().get(0).getKind().equals(SqlKind.LITERAL);
-    final boolean isLiteralRight = 
call.getOperands().get(1).getKind().equals(SqlKind.LITERAL);
-    final boolean isInputRefLeft = 
call.getOperands().get(0).getKind().equals(SqlKind.INPUT_REF);
-    final boolean isInputRefRight = 
call.getOperands().get(1).getKind().equals(SqlKind.INPUT_REF);
+  /**
+   * Return whether the expression is a removable cast based on stats and type 
bounds.
+   *
+   * <p>
+   * There are two main categories of CAST behavior:
+   * <ul>
+   *   <li>Non-representable values will be cast to NULL (c1). As NULL does 
not fulfill the predicate,
+   *     these non-representable values will need to be excluded when 
estimating the selectivity.
+   *     Therefore, the selectivity can be estimated by restricting the 
predicate range to the range of possible
+   *     values of the target type. There might be some minor changes to the 
value due to rounding, which can be
+   *     counterbalanced by adjusting the range predicate slightly.
+   *     <br>
+   *     This category applies in most cases, e.g., when casting
+   *     <ul>
+   *       <li>DECIMAL to an integer type</li>
+   *       <li>an integer type to DECIMAL</li>
+   *       <li>DECIMAL to DECIMAL</li>
+   *       <li>an integer type to a larger integer type, e.g., TINYINT to 
SMALLINT.
+   *         All values are representable, so the condition (c1) is trivially 
fulfilled.
+   *       </li>
+   *     </ul>
+   *   </li>
+   *   <li>If a value cannot be represented by the cast, the value is MODIFIED 
SUBSTANTIALLY.
+   *     For example, CAST(128 as TINYINT) = -128.
+   *     In principle, it is possible to estimate a selectivity using 
histograms.
+   *     However, the implementation would likely be quite complex, the 
estimate of low quality,
+   *     and the gain quite limited, so currently we consider these CASTs as 
non-removable.
+   *   </li>
+   * </ul>
+   * </p>
+   *
+   * @param exp       the expression to check
+   * @param tableScan the table that provides the statistics
+   * @return true if the expression is a removable cast, false otherwise
+   */
+  private boolean isRemovableCast(RexNode exp, HiveTableScan tableScan) {
+    if(SqlKind.CAST != exp.getKind()) {
+      return false;
+    }
+    RexCall cast = (RexCall) exp;
+    RexNode op0 = cast.getOperands().getFirst();
+    if (!(op0 instanceof RexInputRef)) {
+      return false;
+    }
 
-    if (childRel instanceof HiveTableScan && isLiteralLeft != isLiteralRight 
&& isInputRefLeft != isInputRefRight) {
-      final HiveTableScan t = (HiveTableScan) childRel;
-      final int inputRefIndex = ((RexInputRef) 
call.getOperands().get(isInputRefLeft ? 0 : 1)).getIndex();
-      final List<ColStatistics> colStats = 
t.getColStat(Collections.singletonList(inputRefIndex));
+    SqlTypeName sourceType = op0.getType().getSqlTypeName();
+    SqlTypeName targetType = cast.getType().getSqlTypeName();
+
+    switch (sourceType) {
+    case TINYINT, SMALLINT, INTEGER, BIGINT:
+      switch (targetType) {// additional checks are needed
+      case TINYINT, SMALLINT, INTEGER, BIGINT:
+        return isRemovableIntegerCast(cast, op0, tableScan);
+      case FLOAT, DOUBLE, DECIMAL:
+        return true;
+      default:
+        return false;
+      }
+    case FLOAT, DOUBLE, DECIMAL:
+      switch (targetType) {
+      // these CASTs do not show a modulo behavior, so it's ok to remove such 
a cast
+      case TINYINT, SMALLINT, INTEGER, BIGINT, FLOAT, DOUBLE, DECIMAL:
+        return true;
+      default:
+        return false;
+      }
+    case TIMESTAMP, DATE:
+      switch (targetType) {
+      case TIMESTAMP, DATE:
+        return true;
+      default:
+        return false;
+      }
+      // unknown type, do not remove the cast
+    default:
+      return false;
+    }
+  }
 
-      if (!colStats.isEmpty() && isHistogramAvailable(colStats.get(0))) {
-        final KllFloatsSketch kll = 
KllFloatsSketch.heapify(Memory.wrap(colStats.get(0).getHistogram()));
-        final Object boundValueObject = ((RexLiteral) 
call.getOperands().get(isLiteralLeft ? 0 : 1)).getValue();
-        final SqlTypeName typeName = call.getOperands().get(isInputRefLeft ? 0 
: 1).getType().getSqlTypeName();
-        float value = extractLiteral(typeName, boundValueObject);
-        boolean closedBound = op.equals(SqlKind.LESS_THAN_OR_EQUAL) || 
op.equals(SqlKind.GREATER_THAN_OR_EQUAL);
-
-        double selectivity;
-        if (op.equals(SqlKind.LESS_THAN_OR_EQUAL) || 
op.equals(SqlKind.LESS_THAN)) {
-          selectivity = closedBound ? lessThanOrEqualSelectivity(kll, value) : 
lessThanSelectivity(kll, value);
-        } else {
-          selectivity = closedBound ? greaterThanOrEqualSelectivity(kll, 
value) : greaterThanSelectivity(kll, value);
-        }
+  private static boolean isRemovableIntegerCast(RexCall cast, RexNode op0, 
HiveTableScan tableScan) {
+    int inputIndex = ((RexInputRef) op0).getIndex();
+    final List<ColStatistics> colStats = 
tableScan.getColStat(Collections.singletonList(inputIndex));
+    if (colStats.isEmpty()) {
+      return false;
+    }
+
+    // If the source type is completely within the target type, the cast is 
lossless
+    Range<Float> targetRange = getRangeOfType(cast.getType());
+    Range<Float> sourceRange = getRangeOfType(op0.getType());
+    if (targetRange.encloses(sourceRange)) {
+      return true;
+    }
 
-        // selectivity does not account for null values, we multiply for the 
number of non-null values (getN)
-        // and we divide by the total (non-null + null values) to get the 
overall selectivity.
-        //
-        // Example: consider a filter "col < 3", and the following table rows:
-        //  _____
-        // | col |
-        // |_____|
-        // |1    |
-        // |null |
-        // |null |
-        // |3    |
-        // |4    |
-        // -------
-        // kll.getN() would be 3, selectivity 1/3, t.getTable().getRowCount() 5
-        // so the final result would be 3 * 1/3 / 5 = 1/5, as expected.
-        return kll.getN() * selectivity / t.getTable().getRowCount();
+    // Check that the possible values of the input column are all within the 
type range of the cast
+    // otherwise the CAST introduces some modulo-like behavior
+    ColStatistics colStat = colStats.getFirst();
+    ColStatistics.Range colRange = colStat.getRange();
+    if (colRange == null || colRange.minValue == null || colRange.maxValue == 
null) {
+      return false;
+    }
+
+    // are all values of the input column accepted by the cast?
+    SqlTypeName targetType = cast.getType().getSqlTypeName();
+    double min = ((Number) targetType.getLimit(false, 
SqlTypeName.Limit.OVERFLOW, false, -1, -1)).doubleValue();
+    double max = ((Number) targetType.getLimit(true, 
SqlTypeName.Limit.OVERFLOW, false, -1, -1)).doubleValue();
+    return min < colRange.minValue.doubleValue() && 
colRange.maxValue.doubleValue() < max;
+  }
+
+  /**
+   * Get the range of values that are rounded to valid values of a type.
+   *
+   * @param type the type
+   * @return the range of the type
+   */
+  private static Range<Float> getRangeOfType(RelDataType type) {
+    switch (type.getSqlTypeName()) {
+    // in case of integer types,
+    case TINYINT:
+      return Range.closed(-128.99998f, 127.99999f);
+    case SMALLINT:
+      return Range.closed(-32768.996f, 32767.998f);
+    case INTEGER:
+      return Range.closed(-2.1474836E9f, 2.1474836E9f);
+    case BIGINT, DATE, TIMESTAMP:
+      return Range.closed(-9.223372E18f, 9.223372E18f);
+    case DECIMAL:
+      return getRangeOfDecimalType(type);
+    case FLOAT, DOUBLE:
+      return Range.closed(-Float.MAX_VALUE, Float.MAX_VALUE);
+    default:
+      throw new IllegalStateException("Unsupported type: " + type);
+    }
+  }
+
+  private static Range<Float> getRangeOfDecimalType(RelDataType type) {
+    // values outside the representable range are cast to NULL, so adapt the 
boundaries
+    int digits = type.getPrecision() - type.getScale();
+    // the cast does some rounding, i.e., CAST(99.9499 AS DECIMAL(3,1)) = 99.9
+    // but CAST(99.95 AS DECIMAL(3,1)) = NULL
+    float adjust = (float) (5 * Math.pow(10, -(type.getScale() + 1)));
+    // the range of values supported by the type is interval 
[-typeRangeExtent, typeRangeExtent] (both inclusive)
+    // e.g., the typeRangeExt is 99.94999 for DECIMAL(3,1)
+    float typeRangeExtent = Math.nextDown((float) (Math.pow(10, digits) - 
adjust));
+    return Range.closed(-typeRangeExtent, typeRangeExtent);
+  }
+
+  /**
+   * Adjust the type boundaries if necessary.
+   *
+   * @param predicateRange boundaries of the range predicate
+   * @param type the type
+   * @param typeRange the boundaries of the type range
+   * @return the adjusted boundary
+   */
+  private static Range<Float> adjustRangeToType(Range<Float> predicateRange, 
RelDataType type, Range<Float> typeRange) {
+    // Adjusting empty ranges is not needed, and can also lead to invalid 
adjustments
+    if (predicateRange.isEmpty()) {
+      return predicateRange;
+    }
+    if (SqlTypeUtil.isExactNumeric(type)) {
+      // the original boundaries affect the rounding, so save them
+      boolean lowerInclusive = 
BoundType.CLOSED.equals(predicateRange.lowerBoundType());
+      boolean upperInclusive = 
BoundType.CLOSED.equals(predicateRange.upperBoundType());
+      // normalize the range to make the formulas easier
+      Range<Float> range = convertRangeToClosedOpen(predicateRange);
+      typeRange = convertRangeToClosedOpen(typeRange);
+      final float adjustedLower;
+      final float adjustedUpper;
+      if (type.getSqlTypeName() == SqlTypeName.DECIMAL) {
+        // The cast to DECIMAL rounds the value the same way as {@link 
RoundingMode#HALF_UP}.
+        // The boundaries are adjusted accordingly.
+        float adjust = (float) (5 * Math.pow(10, -(type.getScale() + 1)));
+        // the resulting value of +- adjust would be rounded up, so in some 
cases we need to use Math.nextDown
+        adjustedLower = lowerInclusive ? range.lowerEndpoint() - adjust : 
addAndDown(range.lowerEndpoint(), adjust);
+        adjustedUpper = upperInclusive ? addAndDown(range.upperEndpoint(), 
adjust) : range.upperEndpoint() - adjust;
+      } else {
+        // when casting a floating point, its values are rounded towards 0
+        // i.e, 10.99 is rounded to 10, and -10.99 is rounded to -10
+        // to take this into account, the predicate range is transformed in 
the following ways
+        // [10.0, 15.0] -> [10, 15.99999]
+        // (10.0, 15.0) -> [11, 14.99999]
+        // [10.2, 15.2] -> [11, 15.99999]
+        // (10.2, 15.2) -> [11, 15.99999]
+
+        // [-15.0, -10.0] -> [-15.9999, -10]
+        // (-15.0, -10.0) -> [-14.9999, -11]
+        // [-15.2, -10.2] -> [-15.9999, -11]
+        // (-15.2, -10.2) -> [-15.9999, -11]
+        adjustedLower = range.lowerEndpoint() >= 0 ? (float) 
Math.ceil(range.lowerEndpoint())
+            : Math.nextUp(-(float) 
Math.ceil(Math.nextUp(-range.lowerEndpoint())));
+        adjustedUpper = range.upperEndpoint() >= 0 ? Math.nextDown((float) 
Math.ceil(range.upperEndpoint()))
+            : Math.nextUp((float) -Math.ceil(-range.upperEndpoint()));
       }
+      predicateRange = makeRange(adjustedLower, adjustedUpper, BoundType.OPEN);
+    }
+    return typeRange.isConnected(predicateRange) ? 
typeRange.intersection(predicateRange) : Range.closedOpen(0f, 0f);
+  }
+
+  private static float addAndDown(float v, float positiveSummand) {
+    float r = v + positiveSummand;
+    if (r == v) {
+      // the result is below the resolution of float; do not return a value 
smaller than v
+      return r;
+    } else {
+      return Math.nextDown(r);
+    }
+  }
+
+  /**
+   * If the arguments lead to a valid range, it is returned, otherwise an 
empty range is returned.
+   */
+  private static Range<Float> makeRange(float lower, float upper, BoundType 
upperType) {
+    return lower > upper ? Range.closedOpen(0f, 0f) : Range.range(lower, 
BoundType.CLOSED, upper, upperType);
+  }
+
+  private double computeRangePredicateSelectivity(RexCall call, SqlKind op) {
+    double defaultSelectivity = ((double) 1 / (double) 3);
+    if (!(childRel instanceof HiveTableScan)) {
+      return defaultSelectivity;
+    }
+
+    // search for the literal
+    List<RexNode> operands = call.getOperands();
+    final Optional<Float> leftLiteral = extractLiteral(operands.get(0));
+    final Optional<Float> rightLiteral = extractLiteral(operands.get(1));
+    // ensure that there's exactly one literal
+    if ((leftLiteral.isPresent()) == (rightLiteral.isPresent())) {
+      return defaultSelectivity;
+    }
+    int literalOpIdx = leftLiteral.isPresent() ? 0 : 1;
+
+    // analyze the predicate
+    float value = leftLiteral.orElseGet(rightLiteral::get);
+    int boundaryIdx;
+    boolean openBound = op == SqlKind.LESS_THAN || op == SqlKind.GREATER_THAN;
+    switch (op) {
+    case LESS_THAN, LESS_THAN_OR_EQUAL:
+      boundaryIdx = literalOpIdx;
+      break;
+    case GREATER_THAN, GREATER_THAN_OR_EQUAL:
+      boundaryIdx = 1 - literalOpIdx;
+      break;
+    default:
+      return defaultSelectivity;
+    }
+    float[] boundaryValues = new float[] { Float.NEGATIVE_INFINITY, 
Float.POSITIVE_INFINITY };
+    BoundType[] inclusive = new BoundType[] { BoundType.CLOSED, 
BoundType.CLOSED };
+    boundaryValues[boundaryIdx] = value;
+    inclusive[boundaryIdx] = openBound ? BoundType.OPEN : BoundType.CLOSED;
+    Range<Float> boundaries = Range.range(boundaryValues[0], inclusive[0], 
boundaryValues[1], inclusive[1]);
+
+    // extract the column index from the other operator
+    final HiveTableScan scan = (HiveTableScan) childRel;
+    int inputRefOpIndex = 1 - literalOpIdx;
+    RexNode node = operands.get(inputRefOpIndex);
+    if (isRemovableCast(node, scan)) {
+      Range<Float> typeRange = getRangeOfType(node.getType());
+      boundaries = adjustRangeToType(boundaries, node.getType(), typeRange);
+
+      node = RexUtil.removeCast(node);
+    }
+
+    int inputRefIndex = -1;
+    if (node.getKind().equals(SqlKind.INPUT_REF)) {
+      inputRefIndex = ((RexInputRef) node).getIndex();
     }
-    return ((double) 1 / (double) 3);
+
+    if (inputRefIndex < 0) {
+      return defaultSelectivity;
+    }
+
+    final List<ColStatistics> colStats = 
scan.getColStat(Collections.singletonList(inputRefIndex));
+    if (colStats.isEmpty() || !isHistogramAvailable(colStats.get(0))) {
+      return defaultSelectivity;
+    }
+
+    final KllFloatsSketch kll = 
KllFloatsSketch.heapify(Memory.wrap(colStats.get(0).getHistogram()));
+    double rawSelectivity = rangedSelectivity(kll, boundaries);
+    return scaleSelectivityToNullableValues(kll, rawSelectivity, scan);
+  }
+
+  /**
+   * Adjust the selectivity estimate to take NULL values into account.
+   * <p>
+   * The rawSelectivity does not account for null values. We multiply with the 
number of non-null values (getN)
+   * and we divide by the total number (non-null + null values) to get the 
overall selectivity.
+   * <p>
+   * Example: consider a filter "col < 3", and the following table rows:
+   * <pre>
+   *  _____
+   * | col |
+   * |_____|
+   * |1    |
+   * |null |
+   * |null |
+   * |3    |
+   * |4    |
+   * -------
+   * </pre>
+   * kll.getN() would be 3, rawSelectivity 1/3, scan.getTable().getRowCount() 5
+   * so the final result would be 3 * 1/3 / 5 = 1/5, as expected.
+   */
+  private static double scaleSelectivityToNullableValues(KllFloatsSketch kll, 
double rawSelectivity,
+      HiveTableScan scan) {
+    if (scan.getTable() == null) {
+      return rawSelectivity;
+    }
+    return kll.getN() * rawSelectivity / scan.getTable().getRowCount();
   }
 
   private Double computeBetweenPredicateSelectivity(RexCall call) {
-    final boolean hasLiteralBool = 
call.getOperands().get(0).getKind().equals(SqlKind.LITERAL);
-    final boolean hasInputRef = 
call.getOperands().get(1).getKind().equals(SqlKind.INPUT_REF);
-    final boolean hasLiteralLeft = 
call.getOperands().get(2).getKind().equals(SqlKind.LITERAL);
-    final boolean hasLiteralRight = 
call.getOperands().get(3).getKind().equals(SqlKind.LITERAL);
+    if (!(childRel instanceof HiveTableScan)) {
+      return computeFunctionSelectivity(call);
+    }
 
-    if (childRel instanceof HiveTableScan && hasLiteralBool && hasInputRef && 
hasLiteralLeft && hasLiteralRight) {
-      final HiveTableScan t = (HiveTableScan) childRel;
-      final int inputRefIndex = ((RexInputRef) 
call.getOperands().get(1)).getIndex();
-      final List<ColStatistics> colStats = 
t.getColStat(Collections.singletonList(inputRefIndex));
+    List<RexNode> operands = call.getOperands();
+    final boolean hasLiteralBool = 
operands.get(0).getKind().equals(SqlKind.LITERAL);
+    Optional<Float> leftLiteral = extractLiteral(operands.get(2));
+    Optional<Float> rightLiteral = extractLiteral(operands.get(3));
+
+    if (hasLiteralBool && leftLiteral.isPresent() && rightLiteral.isPresent()) 
{
+      final HiveTableScan scan = (HiveTableScan) childRel;
+      float leftValue = leftLiteral.get();
+      float rightValue = rightLiteral.get();
+
+      boolean inverseBool = RexLiteral.booleanValue(operands.getFirst());
+      // when they are equal it's an equality predicate, we cannot handle it 
as "BETWEEN"
+      if (Objects.equals(leftValue, rightValue)) {
+        return inverseBool ? computeNotEqualitySelectivity(call) : 
computeFunctionSelectivity(call);
+      }
 
+      Range<Float> rangeBoundaries = makeRange(leftValue, rightValue, 
BoundType.CLOSED);
+      Range<Float> typeBoundaries = inverseBool ? 
Range.closed(Float.NEGATIVE_INFINITY, Float.POSITIVE_INFINITY) : null;
+
+      RexNode expr = operands.get(1); // expr to be checked by the BETWEEN
+      if (isRemovableCast(expr, scan)) {
+        typeBoundaries = getRangeOfType(expr.getType());
+        rangeBoundaries = adjustRangeToType(rangeBoundaries, expr.getType(), 
typeBoundaries);
+        expr = RexUtil.removeCast(expr);
+      }
+
+      int inputRefIndex = -1;
+      if (expr.getKind().equals(SqlKind.INPUT_REF)) {
+        inputRefIndex = ((RexInputRef) expr).getIndex();
+      }
+
+      if (inputRefIndex < 0) {
+        return computeFunctionSelectivity(call);
+      }
+
+      final List<ColStatistics> colStats = 
scan.getColStat(Collections.singletonList(inputRefIndex));
       if (!colStats.isEmpty() && isHistogramAvailable(colStats.get(0))) {
         final KllFloatsSketch kll = 
KllFloatsSketch.heapify(Memory.wrap(colStats.get(0).getHistogram()));
-        final SqlTypeName typeName = 
call.getOperands().get(1).getType().getSqlTypeName();
-        final Object inverseBoolValueObject = ((RexLiteral) 
call.getOperands().get(0)).getValue();
-        boolean inverseBool = 
Boolean.parseBoolean(inverseBoolValueObject.toString());
-        final Object leftBoundValueObject = ((RexLiteral) 
call.getOperands().get(2)).getValue();
-        float leftValue = extractLiteral(typeName, leftBoundValueObject);
-        final Object rightBoundValueObject = ((RexLiteral) 
call.getOperands().get(3)).getValue();
-        float rightValue = extractLiteral(typeName, rightBoundValueObject);
-        // when inverseBool == true, this is a NOT_BETWEEN and selectivity 
must be inverted
+        double rawSelectivity = rangedSelectivity(kll, rangeBoundaries);
         if (inverseBool) {
-          if (rightValue == leftValue) {
-            return computeNotEqualitySelectivity(call);
-          } else if (rightValue < leftValue) {
-            return 1.0;
-          }
-          return 1.0 - (kll.getN() * betweenSelectivity(kll, leftValue, 
rightValue) / t.getTable().getRowCount());
-        }
-        // when they are equal it's an equality predicate, we cannot handle it 
as "between"
-        if (Double.compare(leftValue, rightValue) != 0) {
-          return kll.getN() * betweenSelectivity(kll, leftValue, rightValue) / 
t.getTable().getRowCount();
+          // when inverseBool == true, this is a NOT_BETWEEN and selectivity 
must be inverted
+          // if there's a cast, the inversion is with respect to its codomain 
(range of the values of the cast)
+          double typeRangeSelectivity = rangedSelectivity(kll, typeBoundaries);
+          rawSelectivity = typeRangeSelectivity - rawSelectivity;
         }
+        return scaleSelectivityToNullableValues(kll, rawSelectivity, scan);
       }
     }
     return computeFunctionSelectivity(call);
   }
 
-  private float extractLiteral(SqlTypeName typeName, Object boundValueObject) {
+  private Optional<Float> extractLiteral(RexNode node) {
+    if (node.getKind() != SqlKind.LITERAL) {
+      return Optional.empty();
+    }
+    RexLiteral literal = (RexLiteral) node;
+    if (literal.getValue() == null) {
+      return Optional.empty();
+    }
+    return extractLiteral(literal.getTypeName(), literal.getValue());
+  }
+
+  private Optional<Float> extractLiteral(SqlTypeName typeName, Object 
boundValueObject) {
     final String boundValueString = boundValueObject.toString();
 
     float value;
@@ -299,10 +598,9 @@ private float extractLiteral(SqlTypeName typeName, Object 
boundValueObject) {
       value = ((GregorianCalendar) 
boundValueObject).toInstant().getEpochSecond();
       break;
     default:
-      throw new IllegalStateException(
-          "Unsupported type for comparator selectivity evaluation using 
histogram: " + typeName);
+      return Optional.empty();
     }
-    return value;
+    return Optional.of(value);
   }
 
   /**
@@ -470,7 +768,7 @@ private boolean isPartitionPredicate(RexNode expr, RelNode 
r) {
     } else if (r instanceof Filter) {
       return isPartitionPredicate(expr, ((Filter) r).getInput());
     } else if (r instanceof HiveTableScan) {
-      RelOptHiveTable table = (RelOptHiveTable) ((HiveTableScan) r).getTable();
+      RelOptHiveTable table = (RelOptHiveTable) r.getTable();
       ImmutableBitSet cols = RelOptUtil.InputFinder.bits(expr);
       return table.containsPartitionColumnsOnly(cols);
     }
@@ -489,7 +787,43 @@ public Double visitLiteral(RexLiteral literal) {
     return null;
   }
 
-  private static double rangedSelectivity(KllFloatsSketch kll, float val1, 
float val2) {
+  /**
+   * Returns the selectivity of a predicate "val1 &lt;= column &lt; val2".
+   * @param kll the sketch
+   * @param boundaries the boundaries
+   * @return the selectivity of "val1 &lt;= column &lt; val2"
+   */
+  private static double rangedSelectivity(KllFloatsSketch kll, Range<Float> 
boundaries) {
+    // convert the condition to a range val1 <= x < val2
+    Range<Float> closedOpen = convertRangeToClosedOpen(boundaries);
+    return rangedSelectivity(kll, closedOpen.lowerEndpoint(), 
closedOpen.upperEndpoint());
+  }
+
+  /**
+   * Normalizes the range to the form "val1 &lt;= column &lt; val2".
+   */
+  private static Range<Float> convertRangeToClosedOpen(Range<Float> 
boundaries) {
+    boolean leftClosed = BoundType.CLOSED.equals(boundaries.lowerBoundType());
+    boolean rightOpen = BoundType.OPEN.equals(boundaries.upperBoundType());
+    if (leftClosed && rightOpen) {
+      return boundaries;
+    }
+    float newLower = leftClosed ? boundaries.lowerEndpoint() : 
Math.nextUp(boundaries.lowerEndpoint());
+    float newUpper = rightOpen ? boundaries.upperEndpoint() : 
Math.nextUp(boundaries.upperEndpoint());
+    return Range.closedOpen(newLower, newUpper);
+  }
+
+  /**
+   * Returns the selectivity of a predicate "val1 &lt;= column &lt; val2".
+   * @param kll the sketch
+   * @param val1 lower bound (inclusive)
+   * @param val2 upper bound (exclusive)
+   * @return the selectivity of "val1 &lt;= column &lt; val2"
+   */
+  static double rangedSelectivity(KllFloatsSketch kll, float val1, float val2) 
{
+    if (val1 >= val2) {
+      return 0;
+    }
     float[] splitPoints = new float[] { val1, val2 };
     double[] boundaries = kll.getCDF(splitPoints, 
QuantileSearchCriteria.EXCLUSIVE);
     return boundaries[1] - boundaries[0];
@@ -574,7 +908,7 @@ public static double betweenSelectivity(KllFloatsSketch 
kll, float leftValue, fl
           "Selectivity for BETWEEN leftValue AND rightValue when the two 
values coincide is not supported, found: "
           + "leftValue = " + leftValue + " and rightValue = " + rightValue);
     }
-    return rangedSelectivity(kll, Math.nextDown(leftValue), 
Math.nextUp(rightValue));
+    return rangedSelectivity(kll, leftValue, Math.nextUp(rightValue));
   }
 
   /**
diff --git 
a/ql/src/test/org/apache/hadoop/hive/ql/optimizer/calcite/stats/TestFilterSelectivityEstimator.java
 
b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/calcite/stats/TestFilterSelectivityEstimator.java
index 4255c756e07..28dc2e1ec34 100644
--- 
a/ql/src/test/org/apache/hadoop/hive/ql/optimizer/calcite/stats/TestFilterSelectivityEstimator.java
+++ 
b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/calcite/stats/TestFilterSelectivityEstimator.java
@@ -17,7 +17,6 @@
  */
 package org.apache.hadoop.hive.ql.optimizer.calcite.stats;
 
-import com.google.common.collect.ImmutableList;
 import org.apache.calcite.jdbc.JavaTypeFactoryImpl;
 import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelOptPlanner;
@@ -26,10 +25,15 @@
 import org.apache.calcite.rel.metadata.RelMetadataQuery;
 import org.apache.calcite.rel.type.RelDataType;
 import org.apache.calcite.rel.type.RelDataTypeFactory;
+import org.apache.calcite.rel.type.RelDataTypeField;
 import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexLiteral;
 import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexUtil;
 import org.apache.calcite.sql.fun.SqlStdOperatorTable;
 import org.apache.calcite.sql.type.SqlTypeName;
+import org.apache.calcite.sql.type.SqlTypeUtil;
 import org.apache.calcite.tools.RelBuilder;
 import org.apache.calcite.util.ImmutableBitSet;
 import org.apache.datasketches.kll.KllFloatsSketch;
@@ -43,6 +47,7 @@
 import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveTableScan;
 import org.apache.hadoop.hive.ql.parse.CalcitePlanner;
 import org.apache.hadoop.hive.ql.plan.ColStatistics;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
 import org.junit.Assert;
 import org.junit.Before;
 import org.junit.BeforeClass;
@@ -51,24 +56,77 @@
 import org.mockito.Mock;
 import org.mockito.junit.MockitoJUnitRunner;
 
+import java.time.Instant;
+import java.time.LocalDate;
+import java.time.LocalTime;
+import java.time.ZoneOffset;
+import java.util.ArrayList;
 import java.util.Collections;
+import java.util.List;
 
+import static org.apache.calcite.sql.type.SqlTypeName.BIGINT;
+import static org.apache.calcite.sql.type.SqlTypeName.INTEGER;
+import static org.apache.calcite.sql.type.SqlTypeName.SMALLINT;
+import static org.apache.calcite.sql.type.SqlTypeName.TINYINT;
 import static 
org.apache.hadoop.hive.ql.optimizer.calcite.stats.FilterSelectivityEstimator.betweenSelectivity;
 import static 
org.apache.hadoop.hive.ql.optimizer.calcite.stats.FilterSelectivityEstimator.greaterThanOrEqualSelectivity;
 import static 
org.apache.hadoop.hive.ql.optimizer.calcite.stats.FilterSelectivityEstimator.greaterThanSelectivity;
 import static 
org.apache.hadoop.hive.ql.optimizer.calcite.stats.FilterSelectivityEstimator.isHistogramAvailable;
 import static 
org.apache.hadoop.hive.ql.optimizer.calcite.stats.FilterSelectivityEstimator.lessThanOrEqualSelectivity;
 import static 
org.apache.hadoop.hive.ql.optimizer.calcite.stats.FilterSelectivityEstimator.lessThanSelectivity;
+import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
 
 @RunWith(MockitoJUnitRunner.class)
 public class TestFilterSelectivityEstimator {
 
   private static final float[] VALUES = { 1, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5, 6, 
7 };
+  private static final float[] VALUES2 = {
+      // rounding for DECIMAL(3,1)
+      // -99.95f and its two predecessors and successors
+      -99.95001f, -99.950005f, -99.95f, -99.94999f, -99.94998f,
+      // some values
+      0f, 1f, 10f,
+      // rounding for DECIMAL(3,1)
+      // 99.95f and its two predecessors and successors
+      99.94998f, 99.94999f, 99.95f, 99.950005f, 99.95001f,
+      // 100f and its two predecessors and successors
+      99.999985f, 99.99999f, 100f, 100.00001f, 100.000015f,
+      // 100.05f and its two predecessors and successors
+      100.04999f, 100.049995f, 100.05f, 100.05001f, 100.05002f,
+      // some values
+      1_000f, 10_000f, 100_000f, 1_000_000f, 1e19f };
+  private static final float[] VALUES3 = {
+      // the closest floats that are CAST to the integer types, and one below 
and above the range
+      -9.223373E18f, -9.223372E18f, 9.223372E18f, 9.223373E18f,  // long
+      -2.147484E9f, -2.1474836E9f, 2.1474836E9f, 2.147484E9f, // integer
+      -32769.0f, -32768.996f, 32767.998f, 32768.0f, // short
+      -129f, -128.99998f, 127.99999f, 128.0f, // byte
+      // numbers for checking the rounding when casting to integer types
+      10f, 10.0001f, 10.9999f, 11f,
+      // corresponding negative values
+      -11f, -10.9999f, -10.0001f, -10f };
+
+  /**
+   * Both dates and timestamps are converted to epoch seconds.
+   * <p>
+   * See {@link 
org.apache.hadoop.hive.ql.udf.generic.GenericUDFToUnixTimeStamp#evaluate(GenericUDF.DeferredObject[])}.
+   */
+  private static final float[] VALUES_TIME = {
+      timestamp("2020-11-01"), timestamp("2020-11-02"), 
timestamp("2020-11-03"), timestamp("2020-11-04"),
+      timestamp("2020-11-05T11:23:45Z"), timestamp("2020-11-06"), 
timestamp("2020-11-07") };
+
   private static final KllFloatsSketch KLL = 
StatisticsTestUtils.createKll(VALUES);
-  private static final float DELTA = Float.MIN_VALUE;
+  private static final KllFloatsSketch KLL2 = 
StatisticsTestUtils.createKll(VALUES2);
+  private static final KllFloatsSketch KLL3 = 
StatisticsTestUtils.createKll(VALUES3);
+  private static final KllFloatsSketch KLL_TIME = 
StatisticsTestUtils.createKll(VALUES_TIME);
+  private static final float DELTA = 1e-7f;
   private static final RexBuilder REX_BUILDER = new RexBuilder(new 
JavaTypeFactoryImpl(new HiveTypeSystemImpl()));
   private static final RelDataTypeFactory TYPE_FACTORY = 
REX_BUILDER.getTypeFactory();
+
   private static RelOptCluster relOptCluster;
   private static RexNode intMinus1;
   private static RexNode int0;
@@ -85,7 +143,6 @@ public class TestFilterSelectivityEstimator {
   private static RexNode inputRef0;
   private static RexNode boolFalse;
   private static RexNode boolTrue;
-  private static ColStatistics stats;
 
   @Mock
   private RelOptSchema schemaMock;
@@ -94,12 +151,14 @@ public class TestFilterSelectivityEstimator {
   @Mock
   private RelMetadataQuery mq;
 
-  private HiveTableScan tableScan;
+  private ColStatistics stats;
   private RelNode scan;
+  private RexNode currentInputRef;
+  private int currentValuesSize;
 
   @BeforeClass
   public static void beforeClass() {
-    RelDataType integerType = TYPE_FACTORY.createSqlType(SqlTypeName.INTEGER);
+    RelDataType integerType = TYPE_FACTORY.createSqlType(INTEGER);
     intMinus1 = REX_BUILDER.makeLiteral(-1, integerType, true);
     int0 = REX_BUILDER.makeLiteral(0, integerType, true);
     int1 = REX_BUILDER.makeLiteral(1, integerType, true);
@@ -113,25 +172,61 @@ public static void beforeClass() {
     int11 = REX_BUILDER.makeLiteral(11, integerType, true);
     boolFalse = REX_BUILDER.makeLiteral(false, 
TYPE_FACTORY.createSqlType(SqlTypeName.BOOLEAN), true);
     boolTrue = REX_BUILDER.makeLiteral(true, 
TYPE_FACTORY.createSqlType(SqlTypeName.BOOLEAN), true);
-    tableType = TYPE_FACTORY.createStructType(ImmutableList.of(integerType), 
ImmutableList.of("f1"));
+    RelDataTypeFactory.Builder b = new 
RelDataTypeFactory.Builder(TYPE_FACTORY);
+    b.add("f_numeric", decimalType(38, 25));
+    b.add("f_decimal10s3", decimalType(10, 3));
+    b.add("f_float", TYPE_FACTORY.createSqlType(SqlTypeName.FLOAT));
+    b.add("f_double", TYPE_FACTORY.createSqlType(SqlTypeName.DOUBLE));
+    b.add("f_tinyint", TYPE_FACTORY.createSqlType(TINYINT));
+    b.add("f_smallint", TYPE_FACTORY.createSqlType(SMALLINT));
+    b.add("f_integer", integerType);
+    b.add("f_bigint", TYPE_FACTORY.createSqlType(BIGINT));
+    b.add("f_timestamp", SqlTypeName.TIMESTAMP);
+    b.add("f_date", SqlTypeName.DATE).build();
+    tableType = b.build();
 
     RelOptPlanner planner = CalcitePlanner.createPlanner(new HiveConf());
     relOptCluster = RelOptCluster.create(planner, REX_BUILDER);
+  }
 
-    stats = new ColStatistics();
-    stats.setHistogram(KLL.toByteArray());
+  private static ColStatistics.Range rangeOf(float[] values) {
+    float min = Float.MAX_VALUE, max = -Float.MAX_VALUE;
+    for (float v : values) {
+      min = Math.min(min, v);
+      max = Math.max(max, v);
+    }
+    return new ColStatistics.Range(min, max);
   }
 
  @Before
  public void before() {
    // Default fixture: column 0 with the VALUES histogram; individual tests may
    // switch column/statistics via useFieldWithValues().
    currentValuesSize = VALUES.length;
    doReturn(tableType).when(tableMock).getRowType();
    // Answer (not a fixed value) so tests that change currentValuesSize see the new row count.
    when(tableMock.getRowCount()).thenAnswer(a -> (double) currentValuesSize);

    RelBuilder relBuilder = HiveRelFactories.HIVE_BUILDER.create(relOptCluster, schemaMock);
    HiveTableScan tableScan =
        new HiveTableScan(relOptCluster, relOptCluster.traitSetOf(HiveRelNode.CONVENTION), tableMock, "table", null,
            false, false);
    scan = relBuilder.push(tableScan).build();
    inputRef0 = REX_BUILDER.makeInputRef(scan, 0);
    currentInputRef = inputRef0;

    // Fresh statistics per test: histogram sketch plus min/max range over VALUES.
    stats = new ColStatistics();
    stats.setHistogram(KLL.toByteArray());
    stats.setRange(rangeOf(VALUES));
  }
+
  /**
   * Points the fixture at the given table column and installs the supplied value
   * distribution (histogram sketch + min/max range) as its column statistics.
   * <p>
   * Note: call this method only at the beginning of a test method.
   *
   * @param fieldname name of the column in {@code tableType}
   * @param values raw values backing the statistics (used for row count and range)
   * @param sketch KLL sketch built over {@code values}
   */
  private void useFieldWithValues(String fieldname, float[] values, KllFloatsSketch sketch) {
    currentValuesSize = values.length;
    stats.setHistogram(sketch.toByteArray());
    stats.setRange(rangeOf(values));
    int fieldIndex = scan.getRowType().getFieldNames().indexOf(fieldname);
    currentInputRef = REX_BUILDER.makeInputRef(scan, fieldIndex);
    doReturn(Collections.singletonList(stats)).when(tableMock).getColStat(Collections.singletonList(fieldIndex));
  }
 
   @Test
@@ -420,7 +515,7 @@ public void 
testComputeRangePredicateSelectivityBetweenLeftLowerThanRight() {
 
   @Test
   public void testComputeRangePredicateSelectivityBetweenLeftEqualsRight() {
-    
doReturn(Collections.singletonList(stats)).when(tableMock).getColStat(Collections.singletonList(0));
+    verify(tableMock, never()).getColStat(any());
     doReturn(10.0).when(mq).getDistinctRowCount(scan, ImmutableBitSet.of(0), 
REX_BUILDER.makeLiteral(true));
     RexNode filter = REX_BUILDER.makeCall(HiveBetween.INSTANCE, boolFalse, 
inputRef0, int3, int3);
     FilterSelectivityEstimator estimator = new 
FilterSelectivityEstimator(scan, mq);
@@ -454,7 +549,7 @@ public void 
testComputeRangePredicateSelectivityNotBetweenRightLowerThanLeft() {
 
   @Test
   public void testComputeRangePredicateSelectivityNotBetweenLeftEqualsRight() {
-    
doReturn(Collections.singletonList(stats)).when(tableMock).getColStat(Collections.singletonList(0));
+    verify(tableMock, never()).getColStat(any());
     RexNode filter = REX_BUILDER.makeCall(HiveBetween.INSTANCE, boolTrue, 
inputRef0, int3, int3);
     FilterSelectivityEstimator estimator = new 
FilterSelectivityEstimator(scan, mq);
     Assert.assertEquals(1, estimator.estimateSelectivity(filter), DELTA);
  public void testComputeRangePredicateSelectivityNotBetweenWithNULLS() {
    doReturn(Collections.singletonList(stats)).when(tableMock).getColStat(Collections.singletonList(0));
    // first operand boolTrue marks the BETWEEN call as negated (NOT BETWEEN)
    RexNode filter = REX_BUILDER.makeCall(HiveBetween.INSTANCE, boolTrue, inputRef0, int1, int3);
    FilterSelectivityEstimator estimator = new FilterSelectivityEstimator(scan, mq);
    // only the values 4, 5, 6, 7 fulfill the condition NOT BETWEEN 1 AND 3
    // (the NULL values do not fulfill the condition)
    Assert.assertEquals(0.2, estimator.estimateSelectivity(filter), DELTA);
  }
+
+  @Test
+  public void testCastMatrix() {
+    // checks many possible combinations of types
+    List<RelDataTypeField> fields = new ArrayList<>(tableType.getFieldList());
+    fields.removeIf(f -> SqlTypeUtil.isDatetime(f.getType()));
+    for (var srcField : fields) {
+      useFieldWithValues(srcField.getName(), VALUES, KLL);
+
+      for (var tgt : fields) {
+        try {
+          RexNode expr = cast(srcField.getName(), tgt.getType());
+          checkSelectivity(3 / 13.f, ge(cast(srcField.getName(), 
tgt.getType()), int5));
+          checkSelectivity(10 / 13.f, lt(cast(srcField.getName(), 
tgt.getType()), int5));
+          checkSelectivity(2 / 13.f, gt(cast(srcField.getName(), 
tgt.getType()), int5));
+          checkSelectivity(11 / 13.f, le(cast(srcField.getName(), 
tgt.getType()), int5));
+
+          checkSelectivity(12 / 13f, ge(cast(srcField.getName(), 
tgt.getType()), int2));
+          checkSelectivity(1 / 13f, lt(cast(srcField.getName(), 
tgt.getType()), int2));
+          checkSelectivity(5 / 13f, gt(cast(srcField.getName(), 
tgt.getType()), int2));
+          checkSelectivity(8 / 13f, le(cast(srcField.getName(), 
tgt.getType()), int2));
+
+          checkBetweenSelectivity(13, VALUES.length, VALUES.length, expr, 0, 
10);
+          checkBetweenSelectivity(3, VALUES.length, VALUES.length, expr, 5, 7);
+          checkBetweenSelectivity(8, VALUES.length, VALUES.length, expr, 1, 2);
+        } catch (Throwable e) {
+          throw new AssertionError("Error when casting from " + 
srcField.getType() + " to " + tgt.getType(), e);
+        }
+      }
+    }
+  }
+
  @Test
  public void testRangePredicateCastIntegerValuesOutsideTypeRange() {
    // use VALUES2, even if the tested types cannot represent its values
    // we're only interested in whether the cast to a smaller integer type results in the default selectivity
    // (the 1/3 results below are that default; the x/28 results use the histogram)
    useFieldWithValues("f_tinyint", VALUES2, KLL2);
    checkSelectivity(16 / 28.f, ge(cast("f_tinyint", TINYINT), int5));
    checkSelectivity(18 / 28.f, ge(cast("f_tinyint", SMALLINT), int5));
    checkSelectivity(20 / 28.f, ge(cast("f_tinyint", INTEGER), int5));
    checkSelectivity(20 / 28.f, ge(cast("f_tinyint", BIGINT), int5));

    useFieldWithValues("f_smallint", VALUES2, KLL2);
    checkSelectivity(1 / 3.f, ge(cast("f_smallint", TINYINT), int5));
    checkSelectivity(18 / 28.f, ge(cast("f_smallint", SMALLINT), int5));
    checkSelectivity(20 / 28.f, ge(cast("f_smallint", INTEGER), int5));
    checkSelectivity(20 / 28.f, ge(cast("f_smallint", BIGINT), int5));

    useFieldWithValues("f_integer", VALUES2, KLL2);
    checkSelectivity(1 / 3.f, ge(cast("f_integer", TINYINT), int5));
    checkSelectivity(1 / 3.f, ge(cast("f_integer", SMALLINT), int5));
    checkSelectivity(20 / 28.f, ge(cast("f_integer", INTEGER), int5));
    checkSelectivity(20 / 28.f, ge(cast("f_integer", BIGINT), int5));

    useFieldWithValues("f_bigint", VALUES2, KLL2);
    checkSelectivity(1 / 3.f, ge(cast("f_bigint", TINYINT), int5));
    checkSelectivity(1 / 3.f, ge(cast("f_bigint", SMALLINT), int5));
    checkSelectivity(1 / 3.f, ge(cast("f_bigint", INTEGER), int5));
    checkSelectivity(20 / 28.f, ge(cast("f_bigint", BIGINT), int5));
  }
+
+  @Test
+  public void testRangePredicateWithCast() {
+    useFieldWithValues("f_numeric", VALUES2, KLL2);
+    RelDataType decimal3s1 = decimalType(3, 1);
+    checkSelectivity(4 / 28.f, ge(cast("f_numeric", decimal3s1), 
literalFloat(1)));
+
+    // values from -99.94999 to 99.94999 (both inclusive)
+    checkSelectivity(7 / 28.f, lt(cast("f_numeric", decimal3s1), 
literalFloat(100)));
+    checkSelectivity(7 / 28.f, le(cast("f_numeric", decimal3s1), 
literalFloat(100)));
+    checkSelectivity(0 / 28.f, gt(cast("f_numeric", decimal3s1), 
literalFloat(100)));
+    checkSelectivity(0 / 28.f, ge(cast("f_numeric", decimal3s1), 
literalFloat(100)));
+
+    RelDataType decimal4s1 = decimalType(4, 1);
+    checkSelectivity(10 / 28.f, lt(cast("f_numeric", decimal4s1), 
literalFloat(100)));
+    checkSelectivity(20 / 28.f, le(cast("f_numeric", decimal4s1), 
literalFloat(100)));
+    checkSelectivity(3 / 28.f, gt(cast("f_numeric", decimal4s1), 
literalFloat(100)));
+    checkSelectivity(13 / 28.f, ge(cast("f_numeric", decimal4s1), 
literalFloat(100)));
+
+    RelDataType decimal2s1 = decimalType(2, 1);
+    checkSelectivity(2 / 28.f, lt(cast("f_numeric", decimal2s1), 
literalFloat(100)));
+    checkSelectivity(2 / 28.f, le(cast("f_numeric", decimal2s1), 
literalFloat(100)));
+    checkSelectivity(0 / 28.f, gt(cast("f_numeric", decimal2s1), 
literalFloat(100)));
+    checkSelectivity(0 / 28.f, ge(cast("f_numeric", decimal2s1), 
literalFloat(100)));
+
+    // expected: 100_000f
+    RelDataType decimal7s1 = decimalType(7, 1);
+    checkSelectivity(1 / 28.f, gt(cast("f_numeric", decimal7s1), 
literalFloat(10000)));
+
+    // expected: 10_000f, 100_000f, because CAST(1_000_000 AS DECIMAL(7,1)) = 
NULL, and similar for even larger values
+    checkSelectivity(2 / 28.f, ge(cast("f_numeric", decimal7s1), 
literalFloat(9999)));
+    checkSelectivity(2 / 28.f, ge(cast("f_numeric", decimal7s1), 
literalFloat(10000)));
+
+    // expected: 100_000f
+    checkSelectivity(1 / 28.f, gt(cast("f_numeric", decimal7s1), 
literalFloat(10000)));
+    checkSelectivity(1 / 28.f, gt(cast("f_numeric", decimal7s1), 
literalFloat(10001)));
+
+    // expected 1f, 10f, 99.94998f, 99.94999f
+    checkSelectivity(4 / 28.f, ge(cast("f_numeric", decimal3s1), 
literalFloat(1)));
+    checkSelectivity(3 / 28.f, gt(cast("f_numeric", decimal3s1), 
literalFloat(1)));
+    // expected -99.94999f, -99.94998f, 0f, 1f
+    checkSelectivity(4 / 28.f, le(cast("f_numeric", decimal3s1), 
literalFloat(1)));
+    checkSelectivity(3 / 28.f, lt(cast("f_numeric", decimal3s1), 
literalFloat(1)));
+  }
+
  /**
   * Asserts the expected selectivities of range predicates comparing {@code field}
   * with midnight (date-only) timestamp literals, over the 7-row VALUES_TIME data.
   *
   * @param field the column reference or CAST expression under test
   */
  private void checkTimeFieldOnMidnightTimestamps(RexNode field) {
    // note: use only values from VALUES_TIME that specify a date without hh:mm:ss!
    checkSelectivity(7 / 7.f, ge(field, literalTimestamp("2020-11-01")));
    checkSelectivity(5 / 7.f, ge(field, literalTimestamp("2020-11-03")));
    checkSelectivity(1 / 7.f, ge(field, literalTimestamp("2020-11-07")));

    checkSelectivity(6 / 7.f, gt(field, literalTimestamp("2020-11-01")));
    checkSelectivity(4 / 7.f, gt(field, literalTimestamp("2020-11-03")));
    checkSelectivity(0 / 7.f, gt(field, literalTimestamp("2020-11-07")));

    checkSelectivity(1 / 7.f, le(field, literalTimestamp("2020-11-01")));
    checkSelectivity(3 / 7.f, le(field, literalTimestamp("2020-11-03")));
    checkSelectivity(7 / 7.f, le(field, literalTimestamp("2020-11-07")));

    checkSelectivity(0 / 7.f, lt(field, literalTimestamp("2020-11-01")));
    checkSelectivity(2 / 7.f, lt(field, literalTimestamp("2020-11-03")));
    checkSelectivity(6 / 7.f, lt(field, literalTimestamp("2020-11-07")));

    checkBetweenSelectivity(2, 7, 7, field, literalTimestamp("2020-11-01"), literalTimestamp("2020-11-02"));
    checkBetweenSelectivity(4, 7, 7, field, literalTimestamp("2020-11-03"), literalTimestamp("2020-11-06"));
  }
+
  /**
   * Asserts the expected selectivities of range predicates comparing {@code field}
   * with an intra-day (non-midnight) timestamp literal present in VALUES_TIME.
   *
   * @param field the column reference or CAST expression under test
   */
  private void checkTimeFieldOnIntraDayTimestamps(RexNode field) {
    checkSelectivity(3 / 7.f, ge(field, literalTimestamp("2020-11-05T11:23:45Z")));
    checkSelectivity(2 / 7.f, gt(field, literalTimestamp("2020-11-05T11:23:45Z")));
    checkSelectivity(5 / 7.f, le(field, literalTimestamp("2020-11-05T11:23:45Z")));
    checkSelectivity(4 / 7.f, lt(field, literalTimestamp("2020-11-05T11:23:45Z")));

    checkBetweenSelectivity(3, 7, 7, field, literalTimestamp("2020-11-03"), literalTimestamp("2020-11-05T11:23:45Z"));
    // note: the same timestamp with seconds ":44" maps to the same float, so use ":43" to get a smaller boundary
    checkBetweenSelectivity(2, 7, 7, field, literalTimestamp("2020-11-03"), literalTimestamp("2020-11-05T11:23:43Z"));
  }
+
+  @Test
+  public void testRangePredicateOnTimestamp() {
+    useFieldWithValues("f_timestamp", VALUES_TIME, KLL_TIME);
+    checkTimeFieldOnMidnightTimestamps(currentInputRef);
+    checkTimeFieldOnIntraDayTimestamps(currentInputRef);
+  }
+
+  @Test
+  public void testRangePredicateOnTimestampWithCast() {
+    useFieldWithValues("f_timestamp", VALUES_TIME, KLL_TIME);
+    RexNode expr1 = cast("f_timestamp", SqlTypeName.DATE);
+    checkTimeFieldOnMidnightTimestamps(expr1);
+    checkTimeFieldOnIntraDayTimestamps(expr1);
+
+    RexNode expr2 = cast("f_timestamp", SqlTypeName.TIMESTAMP);
+    checkTimeFieldOnMidnightTimestamps(expr2);
+    checkTimeFieldOnIntraDayTimestamps(expr2);
+  }
+
  /** Range predicates directly on the date column, without any CAST. */
  @Test
  public void testRangePredicateOnDate() {
    useFieldWithValues("f_date", VALUES_TIME, KLL_TIME);
    checkTimeFieldOnMidnightTimestamps(currentInputRef);

    // it does not make sense to compare with "2020-11-05T11:23:45Z",
    // as that value would not be stored as-is in a date column, but as "2020-11-05" instead
  }
+
  /** Range predicates on the date column wrapped in a CAST to DATE and to TIMESTAMP. */
  @Test
  public void testRangePredicateOnDateWithCast() {
    useFieldWithValues("f_date", VALUES_TIME, KLL_TIME);
    checkTimeFieldOnMidnightTimestamps(cast("f_date", SqlTypeName.DATE));
    checkTimeFieldOnMidnightTimestamps(cast("f_date", SqlTypeName.TIMESTAMP));

    // it does not make sense to compare with "2020-11-05T11:23:45Z",
    // as that value would not be stored as-is in a date column, but as "2020-11-05" instead
  }
+
  /** BETWEEN through a CAST to TINYINT over VALUES3 (boundary and rounding values). */
  @Test
  public void testBetweenWithCastToTinyInt() {
    useFieldWithValues("f_numeric", VALUES3, KLL3);
    float total = VALUES3.length;
    float universe = 10; // the number of values that "survive" the cast
    RexNode cast = cast("f_numeric", TINYINT);
    checkBetweenSelectivity(5, universe, total, cast, 0, 1e20f);
    checkBetweenSelectivity(5, universe, total, cast, -1e20f, 0);
    // empty interval (lower > upper) matches nothing
    checkBetweenSelectivity(0, universe, total, cast, 100f, 0f);

    // check rounding of positive numbers
    checkBetweenSelectivity(3, universe, total, cast, 0, 10);
    checkBetweenSelectivity(4, universe, total, cast, 0, 11);
    checkBetweenSelectivity(4, universe, total, cast, 10, 20);
    checkBetweenSelectivity(1, universe, total, cast, 11, 20);

    // check rounding of negative numbers
    checkBetweenSelectivity(4, universe, total, cast, -20, -10);
    checkBetweenSelectivity(1, universe, total, cast, -20, -11);
    checkBetweenSelectivity(3, universe, total, cast, -10, 0);
    checkBetweenSelectivity(4, universe, total, cast, -11, 0);
  }
+
+  @Test
+  public void testBetweenWithCastToSmallInt() {
+    useFieldWithValues("f_numeric", VALUES3, KLL3);
+    float total = VALUES3.length;
+    float universe = 14; // the number of values that "survive" the cast
+    RexNode cast = cast("f_numeric", SMALLINT);
+    checkBetweenSelectivity(7, universe, total, cast, 0, 1e20f);
+    checkBetweenSelectivity(7, universe, total, cast, -1e20f, 0);
+    checkBetweenSelectivity(0, universe, total, cast, 100f, 0f);
+  }
+
  /** BETWEEN through a CAST to INTEGER over VALUES3 (boundary and rounding values). */
  @Test
  public void testBetweenWithCastToInteger() {
    useFieldWithValues("f_numeric", VALUES3, KLL3);
    float total = VALUES3.length;
    float universe = 18; // the number of values that "survive" the cast
    RexNode cast = cast("f_numeric", INTEGER);
    checkBetweenSelectivity(9, universe, total, cast, 0, 1e20f);
    checkBetweenSelectivity(9, universe, total, cast, -1e20f, 0);
    // empty interval (lower > upper) matches nothing
    checkBetweenSelectivity(0, universe, total, cast, 100f, 0f);
  }
+
  /** BETWEEN through a CAST to BIGINT over VALUES3 (boundary and rounding values). */
  @Test
  public void testBetweenWithCastToBigInt() {
    useFieldWithValues("f_numeric", VALUES3, KLL3);
    float total = VALUES3.length;
    float universe = 22; // the number of values that "survive" the cast
    RexNode cast = cast("f_numeric", BIGINT);
    checkBetweenSelectivity(11, universe, total, cast, 0, 1e20f);
    checkBetweenSelectivity(11, universe, total, cast, -1e20f, 0);
    // empty interval (lower > upper) matches nothing
    checkBetweenSelectivity(0, universe, total, cast, 100f, 0f);
  }
+
+  @Test
+  public void testBetweenWithCastToSmallInt2() {
+    useFieldWithValues("f_numeric", VALUES2, KLL2);
+    float total = VALUES2.length;
+    float universe = 23; // the number of values that "survive" the cast
+    RexNode cast = cast("f_numeric", TINYINT);
+    checkBetweenSelectivity(8, universe, total, cast, 100f, 1000f);
+    checkBetweenSelectivity(17, universe, total, cast, 1f, 100f);
+    checkBetweenSelectivity(0, universe, total, cast, 100f, 0f);
+  }
+
  /** BETWEEN through a CAST to DECIMAL(2,1) over VALUES2. */
  @Test
  public void testBetweenWithCastToDecimal2s1() {
    useFieldWithValues("f_numeric", VALUES2, KLL2);
    float total = VALUES2.length;
    float universe = 2; // the number of values that "survive" the cast
    // inputRef0 is column 0, i.e. f_numeric, so this matches the field selected above
    RexNode cast = REX_BUILDER.makeCast(decimalType(2, 1), inputRef0);
    checkBetweenSelectivity(0, universe, total, cast, 100f, 1000f);
    checkBetweenSelectivity(1, universe, total, cast, 1f, 100f);
    checkBetweenSelectivity(0, universe, total, cast, 100f, 0f);
  }
+
  /** BETWEEN through a CAST to DECIMAL(3,1) over VALUES2. */
  @Test
  public void testBetweenWithCastToDecimal3s1() {
    useFieldWithValues("f_numeric", VALUES2, KLL2);
    float total = VALUES2.length;
    float universe = 7; // the number of values that "survive" the cast
    // inputRef0 is column 0, i.e. f_numeric, so this matches the field selected above
    RexNode cast = REX_BUILDER.makeCast(decimalType(3, 1), inputRef0);
    checkBetweenSelectivity(0, universe, total, cast, 100f, 1000f);
    checkBetweenSelectivity(4, universe, total, cast, 1f, 100f);
    checkBetweenSelectivity(0, universe, total, cast, 100f, 0f);
  }
+
  /** BETWEEN through a CAST to DECIMAL(4,1) over VALUES2. */
  @Test
  public void testBetweenWithCastToDecimal4s1() {
    useFieldWithValues("f_numeric", VALUES2, KLL2);
    float total = VALUES2.length;
    float universe = 23; // the number of values that "survive" the cast
    RexNode cast = REX_BUILDER.makeCast(decimalType(4, 1), inputRef0);
    // the values between -999.94999... and 999.94999... (both inclusive) pass through the cast
    // the values between 99.95 and 100 are rounded up to 100, so they fulfill the BETWEEN
    checkBetweenSelectivity(13, universe, total, cast, 100, 1000);
    checkBetweenSelectivity(14, universe, total, cast, 1f, 100f);
    checkBetweenSelectivity(0, universe, total, cast, 100f, 0f);
  }
+
  /** BETWEEN through a CAST to DECIMAL(7,1) over VALUES2. */
  @Test
  public void testBetweenWithCastToDecimal7s1() {
    useFieldWithValues("f_numeric", VALUES2, KLL2);
    float total = VALUES2.length;
    float universe = 26; // the number of values that "survive" the cast
    RexNode cast = REX_BUILDER.makeCast(decimalType(7, 1), inputRef0);
    checkBetweenSelectivity(14, universe, total, cast, 100, 1000);
    checkBetweenSelectivity(14, universe, total, cast, 1f, 100f);
    checkBetweenSelectivity(0, universe, total, cast, 100f, 0f);
  }
+
  /**
   * Asserts that the estimator computes the expected selectivity for {@code filter},
   * and also for its inverted form when one exists.
   *
   * @param expectedSelectivity expected fraction of matching rows
   * @param filter a comparison call of the shape "col OP value"
   */
  private void checkSelectivity(float expectedSelectivity, RexNode filter) {
    FilterSelectivityEstimator estimator = new FilterSelectivityEstimator(scan, mq);
    Assert.assertEquals(filter.toString(), expectedSelectivity, estimator.estimateSelectivity(filter), DELTA);

    // convert "col OP value" to "value INVERSE_OP col", and check it
    RexNode inverted = RexUtil.invert(REX_BUILDER, (RexCall) filter);
    if (inverted != null) {
      // the assertion message shows the original filter so failures are easy to locate
      Assert.assertEquals(filter.toString(), expectedSelectivity, estimator.estimateSelectivity(inverted), DELTA);
    }
  }
+
  /** Convenience overload that wraps the float bounds into FLOAT literals. */
  private void checkBetweenSelectivity(float expectedEntries, float universe, float total, RexNode value, float lower,
      float upper) {
    checkBetweenSelectivity(expectedEntries, universe, total, value, literalFloat(lower), literalFloat(upper));
  }
+
  /**
   * Asserts the selectivity of both "value BETWEEN lower AND upper" and its negation.
   *
   * @param expectedEntries number of rows expected to match the BETWEEN
   * @param universe number of rows that "survive" the cast in {@code value} (rows that
   *        cast to NULL match neither BETWEEN nor NOT BETWEEN)
   * @param total total row count used as the selectivity denominator
   * @param value the column or CAST expression under test
   * @param lower lower bound literal (inclusive)
   * @param upper upper bound literal (inclusive)
   */
  private void checkBetweenSelectivity(float expectedEntries, float universe, float total, RexNode value, RexNode lower,
      RexNode upper) {
    RexNode betweenFilter = REX_BUILDER.makeCall(HiveBetween.INSTANCE, boolFalse, value, lower, upper);
    FilterSelectivityEstimator estimator = new FilterSelectivityEstimator(scan, mq);
    String between = "BETWEEN " + lower + " AND " + upper;
    float expectedSelectivity = expectedEntries / total;
    String message = between + ": calcite filter " + betweenFilter.toString();
    Assert.assertEquals(message, expectedSelectivity, estimator.estimateSelectivity(betweenFilter), DELTA);

    // invert the filter to a NOT BETWEEN (first operand boolTrue marks negation)
    RexNode invBetween = REX_BUILDER.makeCall(HiveBetween.INSTANCE, boolTrue, value, lower, upper);
    String invMessage = "NOT " + between + ": calcite filter " + invBetween.toString();
    float invExpectedSelectivity = (universe - expectedEntries) / total;
    Assert.assertEquals(invMessage, invExpectedSelectivity, estimator.estimateSelectivity(invBetween), DELTA);
  }
+
+
  /** Builds CAST(fieldname AS typeName). */
  private RexNode cast(String fieldname, SqlTypeName typeName) {
    return cast(fieldname, type(typeName));
  }
+
+  private RexNode cast(String fieldname, RelDataType type) {
+    int fieldIndex = scan.getRowType().getFieldNames().indexOf(fieldname);
+    RexNode column = REX_BUILDER.makeInputRef(scan, fieldIndex);
+    return REX_BUILDER.makeCast(type, column);
+  }
+
  /** Builds "expr >= value". */
  private RexNode ge(RexNode expr, RexNode value) {
    return REX_BUILDER.makeCall(SqlStdOperatorTable.GREATER_THAN_OR_EQUAL, expr, value);
  }
+
  /** Builds "expr > value". */
  private RexNode gt(RexNode expr, RexNode value) {
    return REX_BUILDER.makeCall(SqlStdOperatorTable.GREATER_THAN, expr, value);
  }
+
  /** Builds "expr <= value". */
  private RexNode le(RexNode expr, RexNode value) {
    return REX_BUILDER.makeCall(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, expr, value);
  }
+
  /** Builds "expr < value". */
  private RexNode lt(RexNode expr, RexNode value) {
    return REX_BUILDER.makeCall(SqlStdOperatorTable.LESS_THAN, expr, value);
  }
+
  /** Creates the SQL type with default precision/scale for the given type name. */
  private static RelDataType type(SqlTypeName typeName) {
    return REX_BUILDER.getTypeFactory().createSqlType(typeName);
  }
+
  /** Creates a DECIMAL(precision, scale) type. */
  private static RelDataType decimalType(int precision, int scale) {
    return REX_BUILDER.getTypeFactory().createSqlType(SqlTypeName.DECIMAL, precision, scale);
  }
+
  /** Creates a TIMESTAMP literal from an ISO date ("2020-11-01") or instant ("...T11:23:45Z") string. */
  private static RexLiteral literalTimestamp(String timestamp) {
    return REX_BUILDER.makeLiteral(timestampMillis(timestamp),
        REX_BUILDER.getTypeFactory().createSqlType(SqlTypeName.TIMESTAMP));
  }
+
+  private RexNode literalFloat(float f) {
+    return REX_BUILDER.makeLiteral(f, type(SqlTypeName.FLOAT));
+  }
+
+  private static long timestampMillis(String timestamp) {
+    if (!timestamp.contains(":")) {
+      return LocalDate.parse(timestamp).toEpochSecond(LocalTime.MIDNIGHT, 
ZoneOffset.UTC) * 1000;
+    }
+    return Instant.parse(timestamp).toEpochMilli();
+  }
+
  /** Converts the string to whole epoch seconds (millis truncated), as used by VALUES_TIME. */
  private static long timestamp(String timestamp) {
    return timestampMillis(timestamp) / 1000;
  }
 }

Reply via email to