Hello everyone,

Let me use the following Spark SQL example to demonstrate the issue we're
having:

```
SELECT * FROM small_table
    INNER JOIN big_table ON small_table.foreign_key = big_table.partition_key
    INNER JOIN bigger_table ON big_table.foreign_key = bigger_table.partition_key
WHERE small_table.date = '2022-01-01'
```

The efficient execution plan for the case above is to load the small table
filtered by date (yielding only a few rows), use it to partition prune
big_table so that only the relevant partitions are loaded, join the two
tables together, and then use the join result to partition prune bigger_table.
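Spelled out at the DataFrame level, that plan would look roughly like the
sketch below (untested; the column names are the ones from the query above,
and collecting the keys to the driver assumes the filtered small table really
is small):

```
from pyspark.sql import functions as F

# Pass 1: the date filter makes this a tiny scan.
small = spark.table("small_table").where(F.col("date") == "2022-01-01")

# Pass 2: collect the few foreign keys to the driver and use them as a
# literal filter on the partition column, so only the matching partitions
# of big_table are scanned.
small_keys = [r[0] for r in small.select("foreign_key").distinct().collect()]
big = spark.table("big_table").where(F.col("partition_key").isin(small_keys))
step_1 = small.join(big, small["foreign_key"] == big["partition_key"])

# Pass 3: repeat the trick with the join result against bigger_table.
step_1_keys = [r[0] for r in step_1.select(big["foreign_key"]).distinct().collect()]
bigger = spark.table("bigger_table").where(F.col("partition_key").isin(step_1_keys))
result = step_1.join(bigger, big["foreign_key"] == bigger["partition_key"])
```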

I can't find a way to implement this strategy easily in Spark. Dynamic
partition pruning seems to support only one level of depth: big_table is
partition pruned, but bigger_table is always fully loaded.
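For reference, this is how the behaviour shows up for me in the query plan
(the config name is taken from the Spark 3.x docs, where it is on by default):

```
# Dynamic partition pruning is controlled by this flag (Spark 3.x):
spark.conf.set("spark.sql.optimizer.dynamicPartitionPruning.enabled", "true")

query = spark.sql("""
    SELECT * FROM small_table
        INNER JOIN big_table ON small_table.foreign_key = big_table.partition_key
        INNER JOIN bigger_table ON big_table.foreign_key = bigger_table.partition_key
    WHERE small_table.date = '2022-01-01'
""")

# The big_table scan picks up a dynamic pruning partition filter here,
# while the bigger_table scan stays a plain full scan.
query.explain(mode="formatted")
```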

Spark tries to parallelize things, so it loads all the tables in parallel.
Interestingly, that is not the ideal approach in this case. I'm not sure
whether Spark has a mechanism to cancel pending tasks and adaptively change
the physical execution strategy as new information comes in. In this example,
Spark would ideally cancel loading bigger_table once the small_table /
big_table join result is available and only a small number of rows are
returned; it could then use those rows to partition prune bigger_table,
assuming the partition keys are in the join condition.

The only way I found to implement this strategy is to break the computation
into two steps: persist the first join result to disk, then load it back and
use it to partition prune bigger_table. The code looks something like this:


```
spark.sql("""
Select * FROM small_table
    Inner join big_table on small_table.foreign_key =
big_table.partition_key
    where small_date.date="2022-01-01"

""").write.format("parquet").mode("overwrite").save("path/to/test.parquet")

spark.read.format("parquet").load("path/to/test.parquet").createOrReplaceTempView("step1")

spark.sql("""
Select * FROM step_1
    Inner join bigger_table on step_1.foreign_key =
bigger_table.partition_key
    where step_1.date="2022-01-01"
""").collect()
```

I could not get `persist` to trigger computation for some reason (even after
running `count` on it), which is why I had to save the intermediate result to
parquet and reload it.
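
For context, the `persist` attempt looked roughly like this (same query as
step 1 above):

```
step_1 = spark.sql("""
    SELECT * FROM small_table
        INNER JOIN big_table ON small_table.foreign_key = big_table.partition_key
    WHERE small_table.date = '2022-01-01'
""").persist()

# An action that should materialise the first join before the next query runs.
step_1.count()
step_1.createOrReplaceTempView("step_1")
```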


The issue with the code above, apart from having to write the intermediate
result to disk, is that it requires manually rewriting queries, which is not
convenient, especially for queries with multiple joins.


I'm looking for insights on how to execute the query above efficiently
without having to load the full data of the joined tables.
