Github user ueshin commented on a diff in the pull request: https://github.com/apache/spark/pull/21966#discussion_r207289175 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -4077,81 +4078,84 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val arrayData = classOf[ArrayData].getName val i = ctx.freshName("i") - val pos = ctx.freshName("pos") val value = ctx.freshName("value") - val hsValue = ctx.freshName("hsValue") val size = ctx.freshName("size") - if (elementTypeSupportEquals) { - val ptName = CodeGenerator.primitiveTypeName(elementType) - val unsafeArray = ctx.freshName("unsafeArray") - val (postFix, openHashElementType, hsJavaTypeName, genHsValue, - getter, setter, javaTypeName, primitiveTypeName, arrayDataBuilder) = - elementType match { - case ByteType | ShortType | IntegerType => - ("$mcI$sp", "Int", "int", s"(int) $value", - s"get$ptName($i)", s"set$ptName($pos, $value)", - CodeGenerator.javaType(elementType), ptName, - s""" - |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} - |${ev.value} = $unsafeArray; - """.stripMargin) - case LongType | FloatType | DoubleType => - val signature = elementType match { - case LongType => "$mcJ$sp" - case FloatType => "$mcF$sp" - case DoubleType => "$mcD$sp" - } - (signature, CodeGenerator.boxedType(elementType), - CodeGenerator.javaType(elementType), value, - s"get$ptName($i)", s"set$ptName($pos, $value)", - CodeGenerator.javaType(elementType), ptName, - s""" - |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} - |${ev.value} = $unsafeArray; - """.stripMargin) - case _ => - val genericArrayData = classOf[GenericArrayData].getName - val et = ctx.addReferenceObj("elementType", elementType) - ("", "Object", "Object", value, - s"get($i, $et)", s"update($pos, $value)", "Object", "Ref", - 
s"${ev.value} = new $genericArrayData(new Object[$size]);") - } + val canUseSpecializedHashSet = elementType match { + case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => true + case _ => false + } + if (canUseSpecializedHashSet) { + val jt = CodeGenerator.javaType(elementType) + val ptName = CodeGenerator.primitiveTypeName(jt) + + def genGetValue(array: String): String = + CodeGenerator.getValue(array, elementType, i) + + val (hsPostFix, hsTypeName) = elementType match { + // we cast byte/short to int when writing to the hash set. + case ByteType | ShortType | IntegerType => ("$mcI$sp", "Int") + case LongType => ("$mcJ$sp", ptName) + case FloatType => ("$mcF$sp", ptName) + case DoubleType => ("$mcD$sp", ptName) + } + + // we cast byte/short to int when writing to the hash set. + val hsValueCast = elementType match { + case ByteType | ShortType => "(int) " + case _ => "" + } nullSafeCodeGen(ctx, ev, (array1, array2) => { val notFoundNullElement = ctx.freshName("notFoundNullElement") val nullElementIndex = ctx.freshName("nullElementIndex") val builder = ctx.freshName("builder") - val array = ctx.freshName("array") val openHashSet = classOf[OpenHashSet[_]].getName - val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$openHashElementType()" - val hs = ctx.freshName("hs") + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()" + val hashSet = ctx.freshName("hashSet") val genericArrayData = classOf[GenericArrayData].getName val arrayBuilder = "scala.collection.mutable.ArrayBuilder" - val arrayBuilderClass = s"$arrayBuilder$$of$primitiveTypeName" - val arrayBuilderClassTag = if (primitiveTypeName != "Ref") { - s"scala.reflect.ClassTag$$.MODULE$$.$primitiveTypeName()" - } else { - s"scala.reflect.ClassTag$$.MODULE$$.AnyRef()" - } + val arrayBuilderClass = s"$arrayBuilder$$of$ptName" + val arrayBuilderClassTag = s"scala.reflect.ClassTag$$.MODULE$$.$ptName()" - def withArray2NullCheck(body: String) = - if 
(right.dataType.asInstanceOf[ArrayType].containsNull) { - s""" - |if ($array2.isNullAt($i)) { - | $notFoundNullElement = false; - |} else { - | $body - |} + def withArray2NullCheck(body: String): String = + if (left.dataType.asInstanceOf[ArrayType].containsNull) { + if (right.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($array2.isNullAt($i)) { + | $notFoundNullElement = false; + |} else { + | $body + |} """.stripMargin + } else { + body + } } else { - body + // if array1's element is not nullable, we don't need to track the null element index. + if (right.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if (!$array2.isNullAt($i)) { + | $body + |} + """.stripMargin + } else { + body + } } - val array2Body = + + val writeArray2ToHashSet = withArray2NullCheck( s""" - |$javaTypeName $value = $array2.$getter; - |$hsJavaTypeName $hsValue = $genHsValue; - |$hs.add$postFix($hsValue); - """.stripMargin + |$jt $value = ${genGetValue(array2)}; + |$hashSet.add$hsPostFix($hsValueCast$value); + """.stripMargin) + + // When hitting a null value, put a null holder in the ArrayBuilder. Finally we will + // convert ArrayBuilder to ArrayData and setNull on the slot with null holder. + val nullValueHolder = elementType match { + case ByteType => "(byte) 0" + case ShortType => "(short ) 0" --- End diff -- nit: extra space after `short`
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org