Hello everyone,

Let me use the following Spark SQL example to demonstrate the issue we're having:
```
SELECT *
FROM small_table
INNER JOIN big_table
  ON small_table.foreign_key = big_table.partition_key
INNER JOIN bigger_table
  ON big_table.foreign_key = bigger_table.partition_key
WHERE small_table.date = "2022-01-01"
```

The efficient execution plan in this case is to load the small table filtered by date (resulting in a few rows), use it to partition-prune `big_table` so only the relevant partitions are loaded, join the two tables, and then use the join result to partition-prune `bigger_table`.

I can't find a way to easily implement this strategy in Spark. Dynamic partition pruning seems to support only one level of depth, so `big_table` is partition-pruned, but `bigger_table` is always fully loaded. Spark tries to parallelize things, so it loads all the tables in parallel; interestingly, that is not the ideal approach here. I'm not sure whether Spark has a mechanism to cancel pending tasks and adaptively change the physical execution strategy as new information comes in (in this example, Spark would ideally cancel loading `bigger_table` once the `small_table`/`big_table` join result is available and a small number of rows are returned; it could then use those rows to partition-prune `bigger_table`, assuming the partition keys appear in the join condition).

The only way I found to implement the strategy is to break the computation into two steps: persist the first join result to disk, then load it back and use it to partition-prune the bigger table.
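To make the intended strategy concrete, here is a minimal, hedged PySpark sketch of the "collect keys, then prune" idea. The `in_predicate` helper and the commented query are illustrations against the hypothetical tables from the example, not a tested implementation: the distinct join keys from the small filtered join are collected to the driver and injected as a literal `IN` filter, so the scan of `bigger_table` can be pruned at planning time.

```python
# Hypothetical helper: turn a list of collected join keys into a SQL IN
# predicate that can be used as a static partition filter.
def in_predicate(column, keys):
    quoted = ", ".join(f"'{k}'" for k in keys)
    return f"{column} IN ({quoted})"

# Sketch of the intended plan (assumes a SparkSession named `spark` and the
# tables from the question; shown as comments since it is not runnable here):
#
# step1 = spark.sql("""
#     SELECT * FROM small_table
#     INNER JOIN big_table
#       ON small_table.foreign_key = big_table.partition_key
#     WHERE small_table.date = '2022-01-01'
# """)
# keys = [row.foreign_key for row in
#         step1.select("foreign_key").distinct().collect()]
#
# # The literal IN list lets Spark prune bigger_table's partitions at
# # planning time instead of scanning the whole table.
# result = spark.table("bigger_table").where(
#     in_predicate("partition_key", keys))
```

This still materializes the first join's keys on the driver, so it only helps when that key set is small, but it avoids writing the intermediate result to disk.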
The code looks something like this:

```
spark.sql("""
  SELECT * FROM small_table
  INNER JOIN big_table
    ON small_table.foreign_key = big_table.partition_key
  WHERE small_table.date = "2022-01-01"
""").write.format("parquet").mode("overwrite").save("path/to/test.parquet")

spark.read.format("parquet").load("path/to/test.parquet").createOrReplaceTempView("step_1")

spark.sql("""
  SELECT * FROM step_1
  INNER JOIN bigger_table
    ON step_1.foreign_key = bigger_table.partition_key
  WHERE step_1.date = "2022-01-01"
""").collect()
```

I could not get `persist` to trigger computation for some reason (even after running `count` on it), which is why I had to save the result to Parquet and reload it. Apart from the extra round-trip to disk, the issue with the code above is that it requires manually rewriting queries, which is inconvenient, especially for queries with multiple joins.

I'm looking for some insights on how to efficiently execute the original query without having to fetch the full data of the joined tables.