bersprockets commented on a change in pull request #30207:
URL: https://github.com/apache/spark/pull/30207#discussion_r516069825



##########
File path: 
sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
##########
@@ -726,33 +726,79 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
     val useAdvanced = SQLConf.get.advancedPartitionPredicatePushdownEnabled
 
     object ExtractAttribute {
-      def unapply(expr: Expression): Option[Attribute] = {
+      def unapply(expr: Expression): Option[(Attribute, DataType)] = {
         expr match {
-          case attr: Attribute => Some(attr)
+          case attr: Attribute => Some(attr, attr.dataType)
           case Cast(child @ AtomicType(), dt: AtomicType, _)
               if Cast.canUpCast(child.dataType.asInstanceOf[AtomicType], dt) 
=> unapply(child)
           case _ => None
         }
       }
     }
 
+    def compatibleTypes(dt1: Any, dt2: Any): Boolean =
+      (dt1, dt2) match {
+        case (_: IntegralType, _: IntegralType) => true
+        case (_: StringType, _: StringType) => true
+        case _ => false
+      }
+
+    def compatibleTypesIn(dt1: Any, dts: Seq[Any]): Boolean = {
+      dts.forall(compatibleTypes(dt1, _))
+    }
+
+    def fixValue(quotedValue: String, desiredType: DataType): Option[Any] = 
try {
+      val value = quotedValue.init.tail // remove leading and trailing quotes
+      desiredType match {
+        case LongType =>
+          Some(value.toLong)
+        case IntegerType =>
+          Some(value.toInt)
+        case ShortType =>
+          Some(value.toShort)
+        case ByteType =>
+          Some(value.toByte)
+      }
+    } catch {
+      case _: NumberFormatException =>
+        None
+    }
+
     def convert(expr: Expression): Option[String] = expr match {
-      case In(ExtractAttribute(SupportedAttribute(name)), 
ExtractableLiterals(values))
-          if useAdvanced =>
+      case In(ExtractAttribute(SupportedAttribute(name), dt1), 
ExtractableLiterals(valsAndDts))
+          if useAdvanced && compatibleTypesIn(dt1, valsAndDts.map(_._2)) =>
+        val values = valsAndDts.map(_._1)
         Some(convertInToOr(name, values))
 
-      case InSet(ExtractAttribute(SupportedAttribute(name)), 
ExtractableValues(values))
-          if useAdvanced =>
+      case InSet(ExtractAttribute(SupportedAttribute(name), dt1), 
ExtractableValues(valsAndDts))
+          if useAdvanced && compatibleTypesIn(dt1, valsAndDts.map(_._2)) =>
+        val values = valsAndDts.map(_._1)
         Some(convertInToOr(name, values))
 
       case op @ SpecialBinaryComparison(
-          ExtractAttribute(SupportedAttribute(name)), 
ExtractableLiteral(value)) =>
+          ExtractAttribute(SupportedAttribute(name), dt1), 
ExtractableLiteral(value, dt2))
+          if compatibleTypes(dt1, dt2) =>
         Some(s"$name ${op.symbol} $value")
 
       case op @ SpecialBinaryComparison(
-          ExtractableLiteral(value), 
ExtractAttribute(SupportedAttribute(name))) =>
+          ExtractAttribute(SupportedAttribute(name), dt1), 
ExtractableLiteral(rawValue, dt2))
+          if dt1.isInstanceOf[IntegralType] && dt2.isInstanceOf[StringType] =>
+        fixValue(rawValue, dt1).map { value =>
+          s"$name ${op.symbol} $value"
+        }

Review comment:
      I don't have an equivalent "attempt to correct" for In and InSet, only
for binary comparisons. In the case of In and InSet, if the data types are not
compatible, I just drop the filter (which is what would have happened before
SPARK-22384).

##########
File path: 
sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
##########
@@ -726,33 +726,79 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
     val useAdvanced = SQLConf.get.advancedPartitionPredicatePushdownEnabled
 
     object ExtractAttribute {
-      def unapply(expr: Expression): Option[Attribute] = {
+      def unapply(expr: Expression): Option[(Attribute, DataType)] = {
         expr match {
-          case attr: Attribute => Some(attr)
+          case attr: Attribute => Some(attr, attr.dataType)
           case Cast(child @ AtomicType(), dt: AtomicType, _)
               if Cast.canUpCast(child.dataType.asInstanceOf[AtomicType], dt) 
=> unapply(child)
           case _ => None
         }
       }
     }
 
+    def compatibleTypes(dt1: Any, dt2: Any): Boolean =
+      (dt1, dt2) match {
+        case (_: IntegralType, _: IntegralType) => true
+        case (_: StringType, _: StringType) => true
+        case _ => false
+      }
+
+    def compatibleTypesIn(dt1: Any, dts: Seq[Any]): Boolean = {
+      dts.forall(compatibleTypes(dt1, _))
+    }
+
+    def fixValue(quotedValue: String, desiredType: DataType): Option[Any] = 
try {
+      val value = quotedValue.init.tail // remove leading and trailing quotes
+      desiredType match {
+        case LongType =>
+          Some(value.toLong)
+        case IntegerType =>
+          Some(value.toInt)
+        case ShortType =>
+          Some(value.toShort)
+        case ByteType =>
+          Some(value.toByte)
+      }
+    } catch {
+      case _: NumberFormatException =>
+        None
+    }
+
     def convert(expr: Expression): Option[String] = expr match {
-      case In(ExtractAttribute(SupportedAttribute(name)), 
ExtractableLiterals(values))
-          if useAdvanced =>
+      case In(ExtractAttribute(SupportedAttribute(name), dt1), 
ExtractableLiterals(valsAndDts))
+          if useAdvanced && compatibleTypesIn(dt1, valsAndDts.map(_._2)) =>
+        val values = valsAndDts.map(_._1)
         Some(convertInToOr(name, values))
 
-      case InSet(ExtractAttribute(SupportedAttribute(name)), 
ExtractableValues(values))
-          if useAdvanced =>
+      case InSet(ExtractAttribute(SupportedAttribute(name), dt1), 
ExtractableValues(valsAndDts))
+          if useAdvanced && compatibleTypesIn(dt1, valsAndDts.map(_._2)) =>
+        val values = valsAndDts.map(_._1)
         Some(convertInToOr(name, values))
 
       case op @ SpecialBinaryComparison(
-          ExtractAttribute(SupportedAttribute(name)), 
ExtractableLiteral(value)) =>
+          ExtractAttribute(SupportedAttribute(name), dt1), 
ExtractableLiteral(value, dt2))
+          if compatibleTypes(dt1, dt2) =>
         Some(s"$name ${op.symbol} $value")
 
       case op @ SpecialBinaryComparison(
-          ExtractableLiteral(value), 
ExtractAttribute(SupportedAttribute(name))) =>
+          ExtractAttribute(SupportedAttribute(name), dt1), 
ExtractableLiteral(rawValue, dt2))
+          if dt1.isInstanceOf[IntegralType] && dt2.isInstanceOf[StringType] =>
+        fixValue(rawValue, dt1).map { value =>

Review comment:
      Yes, it should probably ignore any string literals with leading zeros.

##########
File path: 
sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
##########
@@ -726,33 +726,79 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
     val useAdvanced = SQLConf.get.advancedPartitionPredicatePushdownEnabled
 
     object ExtractAttribute {
-      def unapply(expr: Expression): Option[Attribute] = {
+      def unapply(expr: Expression): Option[(Attribute, DataType)] = {
         expr match {
-          case attr: Attribute => Some(attr)
+          case attr: Attribute => Some(attr, attr.dataType)
           case Cast(child @ AtomicType(), dt: AtomicType, _)
               if Cast.canUpCast(child.dataType.asInstanceOf[AtomicType], dt) 
=> unapply(child)
           case _ => None
         }
       }
     }
 
+    def compatibleTypes(dt1: Any, dt2: Any): Boolean =
+      (dt1, dt2) match {
+        case (_: IntegralType, _: IntegralType) => true
+        case (_: StringType, _: StringType) => true
+        case _ => false
+      }
+
+    def compatibleTypesIn(dt1: Any, dts: Seq[Any]): Boolean = {
+      dts.forall(compatibleTypes(dt1, _))
+    }
+
+    def fixValue(quotedValue: String, desiredType: DataType): Option[Any] = 
try {
+      val value = quotedValue.init.tail // remove leading and trailing quotes
+      desiredType match {
+        case LongType =>
+          Some(value.toLong)
+        case IntegerType =>
+          Some(value.toInt)
+        case ShortType =>
+          Some(value.toShort)
+        case ByteType =>
+          Some(value.toByte)
+      }
+    } catch {
+      case _: NumberFormatException =>
+        None
+    }
+
     def convert(expr: Expression): Option[String] = expr match {
-      case In(ExtractAttribute(SupportedAttribute(name)), 
ExtractableLiterals(values))
-          if useAdvanced =>
+      case In(ExtractAttribute(SupportedAttribute(name), dt1), 
ExtractableLiterals(valsAndDts))
+          if useAdvanced && compatibleTypesIn(dt1, valsAndDts.map(_._2)) =>
+        val values = valsAndDts.map(_._1)
         Some(convertInToOr(name, values))
 
-      case InSet(ExtractAttribute(SupportedAttribute(name)), 
ExtractableValues(values))
-          if useAdvanced =>
+      case InSet(ExtractAttribute(SupportedAttribute(name), dt1), 
ExtractableValues(valsAndDts))
+          if useAdvanced && compatibleTypesIn(dt1, valsAndDts.map(_._2)) =>
+        val values = valsAndDts.map(_._1)
         Some(convertInToOr(name, values))
 
       case op @ SpecialBinaryComparison(
-          ExtractAttribute(SupportedAttribute(name)), 
ExtractableLiteral(value)) =>
+          ExtractAttribute(SupportedAttribute(name), dt1), 
ExtractableLiteral(value, dt2))
+          if compatibleTypes(dt1, dt2) =>
         Some(s"$name ${op.symbol} $value")
 
       case op @ SpecialBinaryComparison(
-          ExtractableLiteral(value), 
ExtractAttribute(SupportedAttribute(name))) =>
+          ExtractAttribute(SupportedAttribute(name), dt1), 
ExtractableLiteral(rawValue, dt2))
+          if dt1.isInstanceOf[IntegralType] && dt2.isInstanceOf[StringType] =>
+        fixValue(rawValue, dt1).map { value =>

Review comment:
       >perhaps we should do it in UnwrapCastInBinaryComparison so that it can 
not only be used by Hive but also other data sources.
   
   Whatever makes sense. There is some long-running work on
TypeCoercion (#22038) that fixes a few of these cases. If that goes
through and we can close the gap for the remaining cases, that would be fine. I am
probably not in a position to provide much help with the optimizer code at this
point.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to