Github user kiszk commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21966#discussion_r207302021
  
    --- Diff: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
 ---
    @@ -4077,81 +4078,84 @@ case class ArrayExcept(left: Expression, right: 
Expression) extends ArraySetLike
       override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
         val arrayData = classOf[ArrayData].getName
         val i = ctx.freshName("i")
    -    val pos = ctx.freshName("pos")
         val value = ctx.freshName("value")
    -    val hsValue = ctx.freshName("hsValue")
         val size = ctx.freshName("size")
    -    if (elementTypeSupportEquals) {
    -      val ptName = CodeGenerator.primitiveTypeName(elementType)
    -      val unsafeArray = ctx.freshName("unsafeArray")
    -      val (postFix, openHashElementType, hsJavaTypeName, genHsValue,
    -           getter, setter, javaTypeName, primitiveTypeName, 
arrayDataBuilder) =
    -        elementType match {
    -          case ByteType | ShortType | IntegerType =>
    -            ("$mcI$sp", "Int", "int", s"(int) $value",
    -              s"get$ptName($i)", s"set$ptName($pos, $value)",
    -              CodeGenerator.javaType(elementType), ptName,
    -              s"""
    -                 |${ctx.createUnsafeArray(unsafeArray, size, elementType, 
s" $prettyName failed.")}
    -                 |${ev.value} = $unsafeArray;
    -               """.stripMargin)
    -          case LongType | FloatType | DoubleType =>
    -            val signature = elementType match {
    -              case LongType => "$mcJ$sp"
    -              case FloatType => "$mcF$sp"
    -              case DoubleType => "$mcD$sp"
    -            }
    -            (signature, CodeGenerator.boxedType(elementType),
    -              CodeGenerator.javaType(elementType), value,
    -              s"get$ptName($i)", s"set$ptName($pos, $value)",
    -              CodeGenerator.javaType(elementType), ptName,
    -              s"""
    -                 |${ctx.createUnsafeArray(unsafeArray, size, elementType, 
s" $prettyName failed.")}
    -                 |${ev.value} = $unsafeArray;
    -               """.stripMargin)
    -          case _ =>
    -            val genericArrayData = classOf[GenericArrayData].getName
    -            val et = ctx.addReferenceObj("elementType", elementType)
    -            ("", "Object", "Object", value,
    -              s"get($i, $et)", s"update($pos, $value)", "Object", "Ref",
    -              s"${ev.value} = new $genericArrayData(new Object[$size]);")
    -        }
    +    val canUseSpecializedHashSet = elementType match {
    +      case ByteType | ShortType | IntegerType | LongType | FloatType | 
DoubleType => true
    +      case _ => false
    +    }
    +    if (canUseSpecializedHashSet) {
    +      val jt = CodeGenerator.javaType(elementType)
    +      val ptName = CodeGenerator.primitiveTypeName(jt)
    +
    +      def genGetValue(array: String): String =
    +        CodeGenerator.getValue(array, elementType, i)
    +
    +      val (hsPostFix, hsTypeName) = elementType match {
    +        // we cast byte/short to int when writing to the hash set.
    +        case ByteType | ShortType | IntegerType => ("$mcI$sp", "Int")
    +        case LongType => ("$mcJ$sp", ptName)
    +        case FloatType => ("$mcF$sp", ptName)
    +        case DoubleType => ("$mcD$sp", ptName)
    +      }
    +
    +      // we cast byte/short to int when writing to the hash set.
    +      val hsValueCast = elementType match {
    +        case ByteType | ShortType => "(int) "
    +        case _ => ""
    +      }
     
           nullSafeCodeGen(ctx, ev, (array1, array2) => {
             val notFoundNullElement = ctx.freshName("notFoundNullElement")
             val nullElementIndex = ctx.freshName("nullElementIndex")
             val builder = ctx.freshName("builder")
    -        val array = ctx.freshName("array")
             val openHashSet = classOf[OpenHashSet[_]].getName
    -        val classTag = 
s"scala.reflect.ClassTag$$.MODULE$$.$openHashElementType()"
    -        val hs = ctx.freshName("hs")
    +        val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
    +        val hashSet = ctx.freshName("hashSet")
             val genericArrayData = classOf[GenericArrayData].getName
             val arrayBuilder = "scala.collection.mutable.ArrayBuilder"
    -        val arrayBuilderClass = s"$arrayBuilder$$of$primitiveTypeName"
    -        val arrayBuilderClassTag = if (primitiveTypeName != "Ref") {
    -          s"scala.reflect.ClassTag$$.MODULE$$.$primitiveTypeName()"
    -        } else {
    -          s"scala.reflect.ClassTag$$.MODULE$$.AnyRef()"
    -        }
    +        val arrayBuilderClass = s"$arrayBuilder$$of$ptName"
    +        val arrayBuilderClassTag = 
s"scala.reflect.ClassTag$$.MODULE$$.$ptName()"
     
    -        def withArray2NullCheck(body: String) =
    -          if (right.dataType.asInstanceOf[ArrayType].containsNull) {
    -            s"""
    -               |if ($array2.isNullAt($i)) {
    -               |  $notFoundNullElement = false;
    -               |} else {
    -               |  $body
    -               |}
    +        def withArray2NullCheck(body: String): String =
    +          if (left.dataType.asInstanceOf[ArrayType].containsNull) {
    --- End diff --
    
    Would it be better to use the following structure, so that the `else` clause is shared?
    
    ```
    if (right.dataType.asInstanceOf[ArrayType].containsNull) {
      if (left.dataType.asInstanceOf[ArrayType].containsNull) {
        ...
      } else {
        ...
      }
    } else {
      body
    }
    ```


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to