Felix Wollschläger created SPARK-33383:
------------------------------------------

             Summary: Improve performance of Column.isin Expression
                 Key: SPARK-33383
                 URL: https://issues.apache.org/jira/browse/SPARK-33383
             Project: Spark
          Issue Type: Improvement
          Components: SQL
    Affects Versions: 3.0.1, 2.4.4
         Environment: macOS
Spark(-SQL) 2.4.4 and 3.0.1
Scala 2.12.10
            Reporter: Felix Wollschläger


After asking [a question on Stackoverflow|https://stackoverflow.com/questions/64683189/usage-of-broadcast-variables-when-using-only-spark-sql-api] and running some local tests, I came across a performance bottleneck when using `Column.isin` in a `where` condition.

I have a set of allowed values ("whitelist") that easily fits in memory (about 10k values). I thought simply using the `Column.isin` expression in the SQL API would be the way to go, and assumed it would be runtime-equivalent to
```scala
df.filter(row => allowedValues.contains(row.getInt(0)))
```

However, when running a few tests locally, I realized that `Column.isin` is actually about 10 times slower than an `rdd.filter` or a broadcast inner join.

Shouldn't `df.where(col("colname").isin(allowedValues))` perform (SQL-API overhead aside) as well as `df.filter(row => allowedValues.contains(row.getInt(0)))`?
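One way to see where the overhead comes from is to compare the plans Spark generates for the two variants. A minimal sketch (assuming the `df` and `allowedValues` from the repro below):

```scala
// Print the parsed/analyzed/optimized/physical plans for both variants.
// The isInCollection version should show an In (or InSet) predicate,
// while the typed filter shows up as an opaque lambda in a filter node.
df.where(col("value").isInCollection(allowedValues)).explain(true)
df.filter(row => allowedValues.contains(row.getInt(0))).explain(true)
```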

```scala
package example

import org.apache.spark.sql.functions.{broadcast, col, count}
import org.apache.spark.sql.{DataFrame, SparkSession}

import scala.util.Random

object Test {

    def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
            .appName("Name")
            .master("local[*]")
            .config("spark.driver.host", "localhost")
            .config("spark.ui.enabled", "false")
            .getOrCreate()

        import spark.implicits._

        val _10Million = 10000000
        val random = new Random(1048394789305L)

        val values = Seq.fill(_10Million)(random.nextInt())
        val df = Seq.fill(_10Million)(random.nextInt()).toDF("value")
        val allowedValues = getRandomElements(values, random, 10000)

        println("Starting ...")
        runWithInCollection(spark, df, allowedValues)
        println("---- In Collection")
        runWithBroadcastDF(spark, df, allowedValues)
        println("---- Broadcast DF")
        runWithBroadcastVariable(spark, df, allowedValues)
        println("---- Broadcast Variable")
    }

    def getRandomElements[A](seq: Seq[A], random: Random, size: Int): Set[A] = {
        val builder = Set.newBuilder[A]

        for (i <- 0 until size) {
            builder += getRandomElement(seq, random)
        }

        builder.result()
    }

    def getRandomElement[A](seq: Seq[A], random: Random): A = {
        seq(random.nextInt(seq.length))
    }

    // I expected this one to be almost equivalent to the one with a
    // broadcast variable, but it's actually about 10 times slower
    def runWithInCollection(spark: SparkSession, df: DataFrame, allowedValues: Set[Int]): Unit = {
        spark.time {
            df.where(col("value").isInCollection(allowedValues)).runTestAggregation()
        }
    }

    // A bit slower than the one with a broadcast variable
    def runWithBroadcastDF(spark: SparkSession, df: DataFrame, allowedValues: Set[Int]): Unit = {
        import spark.implicits._

        val allowedValuesDF = allowedValues.toSeq.toDF("allowedValue")

        spark.time {
            df.join(broadcast(allowedValuesDF), col("value") === col("allowedValue")).runTestAggregation()
        }
    }

    // This is actually the fastest one
    def runWithBroadcastVariable(spark: SparkSession, df: DataFrame, allowedValues: Set[Int]): Unit = {
        val allowedValuesBroadcast = spark.sparkContext.broadcast(allowedValues)

        spark.time {
            df.filter(row => allowedValuesBroadcast.value.contains(row.getInt(0))).runTestAggregation()
        }
    }

    implicit class TestRunner(val df: DataFrame) {

        def runTestAggregation(): Unit = {
            df.agg(count("value")).show()
        }
    }
}
```
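
For comparison, a left semi join against the broadcast allowed-values DataFrame is another way to express the same whitelist filter without `isin`. A minimal sketch, assuming the `df` and `allowedValuesDF` from the repro above:

```scala
// Keep only rows whose value appears in the (broadcast) whitelist.
// Semantically equivalent to the isInCollection filter above, but
// executed as a broadcast join rather than an In/InSet predicate.
df.join(broadcast(allowedValuesDF), col("value") === col("allowedValue"), "left_semi")
  .agg(count("value"))
  .show()
```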


