This is an automated email from the ASF dual-hosted git repository.

aokolnychyi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg.git


The following commit(s) were added to refs/heads/main by this push:
     new c745ac3b6a Spark 3.5: Support executor cache locality (#9563)
c745ac3b6a is described below

commit c745ac3b6a3b2f24ae5170aa18f9eeec7cf3cbc7
Author: Anton Okolnychyi <[email protected]>
AuthorDate: Mon Feb 5 10:33:51 2024 -0800

    Spark 3.5: Support executor cache locality (#9563)
---
 .../java/org/apache/iceberg/MockFileScanTask.java  |  11 ++
 .../spark/extensions/TestMergeOnReadDelete.java    |  25 +++
 .../org/apache/iceberg/spark/SparkReadConf.java    |  20 ++
 .../apache/iceberg/spark/SparkSQLProperties.java   |   4 +
 .../java/org/apache/iceberg/spark/SparkUtil.java   |  30 +++
 .../apache/iceberg/spark/source/SparkBatch.java    |  45 +++--
 .../iceberg/spark/source/SparkInputPartition.java  |  13 +-
 .../spark/source/SparkMicroBatchStream.java        |  32 ++--
 .../iceberg/spark/source/SparkPlanningUtil.java    |  93 +++++++++
 .../spark/source/TestSparkPlanningUtil.java        | 213 +++++++++++++++++++++
 10 files changed, 444 insertions(+), 42 deletions(-)

diff --git a/core/src/test/java/org/apache/iceberg/MockFileScanTask.java b/core/src/test/java/org/apache/iceberg/MockFileScanTask.java
index 58275ad3f0..565433c82c 100644
--- a/core/src/test/java/org/apache/iceberg/MockFileScanTask.java
+++ b/core/src/test/java/org/apache/iceberg/MockFileScanTask.java
@@ -44,6 +44,17 @@ public class MockFileScanTask extends BaseFileScanTask {
     this.length = file.fileSizeInBytes();
   }
 
+  public MockFileScanTask(DataFile file, Schema schema, PartitionSpec spec) {
+    super(file, null, SchemaParser.toJson(schema), PartitionSpecParser.toJson(spec), null);
+    this.length = file.fileSizeInBytes();
+  }
+
+  public MockFileScanTask(
+      DataFile file, DeleteFile[] deleteFiles, Schema schema, PartitionSpec spec) {
+    super(file, deleteFiles, SchemaParser.toJson(schema), PartitionSpecParser.toJson(spec), null);
+    this.length = file.fileSizeInBytes();
+  }
+
   public static MockFileScanTask mockTask(long length, int sortOrderId) {
     DataFile mockFile = Mockito.mock(DataFile.class);
     Mockito.when(mockFile.fileSizeInBytes()).thenReturn(length);
diff --git a/spark/v3.5/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadDelete.java b/spark/v3.5/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadDelete.java
index 01f24c4dfe..91600d4df0 100644
--- a/spark/v3.5/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadDelete.java
+++ b/spark/v3.5/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadDelete.java
@@ -37,6 +37,7 @@ import org.apache.iceberg.exceptions.CommitStateUnknownException;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.spark.SparkSQLProperties;
 import org.apache.iceberg.spark.source.SparkTable;
 import org.apache.iceberg.spark.source.TestSparkCatalog;
 import org.apache.iceberg.util.SnapshotUtil;
@@ -85,6 +86,30 @@ public class TestMergeOnReadDelete extends TestDelete {
     TestSparkCatalog.clearTables();
   }
 
+  @Test
+  public void testDeleteWithExecutorCacheLocality() throws NoSuchTableException {
+    createAndInitPartitionedTable();
+
+    append(tableName, new Employee(1, "hr"), new Employee(2, "hr"));
+    append(tableName, new Employee(3, "hr"), new Employee(4, "hr"));
+    append(tableName, new Employee(1, "hardware"), new Employee(2, "hardware"));
+    append(tableName, new Employee(3, "hardware"), new Employee(4, "hardware"));
+
+    createBranchIfNeeded();
+
+    withSQLConf(
+        ImmutableMap.of(SparkSQLProperties.EXECUTOR_CACHE_LOCALITY_ENABLED, "true"),
+        () -> {
+          sql("DELETE FROM %s WHERE id = 1", commitTarget());
+          sql("DELETE FROM %s WHERE id = 3", commitTarget());
+
+          assertEquals(
+              "Should have expected rows",
+              ImmutableList.of(row(2, "hardware"), row(2, "hr"), row(4, "hardware"), row(4, "hr")),
+              sql("SELECT * FROM %s ORDER BY id ASC, dep ASC", selectTarget()));
+        });
+  }
+
   @Test
   public void testDeleteFileGranularity() throws NoSuchTableException {
     checkDeleteFileGranularity(DeleteGranularity.FILE);
diff --git a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkReadConf.java b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkReadConf.java
index 984e2bce1e..2990d981d0 100644
--- a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkReadConf.java
+++ b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkReadConf.java
@@ -331,4 +331,24 @@ public class SparkReadConf {
     SparkConf sparkConf = spark.sparkContext().conf();
     return sparkConf.getSizeAsBytes(DRIVER_MAX_RESULT_SIZE, DRIVER_MAX_RESULT_SIZE_DEFAULT);
   }
+
+  public boolean executorCacheLocalityEnabled() {
+    return executorCacheEnabled() && executorCacheLocalityEnabledInternal();
+  }
+
+  private boolean executorCacheEnabled() {
+    return confParser
+        .booleanConf()
+        .sessionConf(SparkSQLProperties.EXECUTOR_CACHE_ENABLED)
+        .defaultValue(SparkSQLProperties.EXECUTOR_CACHE_ENABLED_DEFAULT)
+        .parse();
+  }
+
+  private boolean executorCacheLocalityEnabledInternal() {
+    return confParser
+        .booleanConf()
+        .sessionConf(SparkSQLProperties.EXECUTOR_CACHE_LOCALITY_ENABLED)
+        .defaultValue(SparkSQLProperties.EXECUTOR_CACHE_LOCALITY_ENABLED_DEFAULT)
+        .parse();
+  }
 }
diff --git a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java
index 4a66520231..ea8f6fe071 100644
--- a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java
+++ b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java
@@ -86,4 +86,8 @@ public class SparkSQLProperties {
   public static final String EXECUTOR_CACHE_MAX_TOTAL_SIZE =
       "spark.sql.iceberg.executor-cache.max-total-size";
   public static final long EXECUTOR_CACHE_MAX_TOTAL_SIZE_DEFAULT = 128 * 1024 * 1024; // 128 MB
+
+  public static final String EXECUTOR_CACHE_LOCALITY_ENABLED =
+      "spark.sql.iceberg.executor-cache.locality.enabled";
+  public static final boolean EXECUTOR_CACHE_LOCALITY_ENABLED_DEFAULT = false;
 }
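
Note: the new flag above is off by default and is read from the Spark session conf, so it can be toggled per session without table changes. A minimal sketch of enabling it from Java (a running SparkSession named `spark` is assumed; it is not part of this diff):

    import org.apache.iceberg.spark.SparkSQLProperties;
    import org.apache.spark.sql.SparkSession;

    class EnableExecutorCacheLocality {
      static void enable(SparkSession spark) {
        // Locality only takes effect when the executor cache itself is enabled,
        // per SparkReadConf#executorCacheLocalityEnabled above.
        spark.conf().set(SparkSQLProperties.EXECUTOR_CACHE_ENABLED, "true");
        spark.conf().set(SparkSQLProperties.EXECUTOR_CACHE_LOCALITY_ENABLED, "true");
      }
    }
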
diff --git a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkUtil.java b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkUtil.java
index 2357ca0441..de06cceb26 100644
--- a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkUtil.java
+++ b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkUtil.java
@@ -34,6 +34,8 @@ import org.apache.iceberg.relocated.com.google.common.collect.Lists;
 import org.apache.iceberg.transforms.Transform;
 import org.apache.iceberg.transforms.UnknownTransform;
 import org.apache.iceberg.util.Pair;
+import org.apache.spark.SparkEnv;
+import org.apache.spark.scheduler.ExecutorCacheTaskLocation;
 import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.catalyst.expressions.BoundReference;
 import org.apache.spark.sql.catalyst.expressions.EqualTo;
@@ -43,7 +45,12 @@ import org.apache.spark.sql.connector.expressions.NamedReference;
 import org.apache.spark.sql.types.DataType;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.StructType;
+import org.apache.spark.storage.BlockManager;
+import org.apache.spark.storage.BlockManagerId;
+import org.apache.spark.storage.BlockManagerMaster;
 import org.joda.time.DateTime;
+import scala.collection.JavaConverters;
+import scala.collection.Seq;
 
 public class SparkUtil {
   private static final String SPARK_CATALOG_CONF_PREFIX = "spark.sql.catalog";
@@ -238,4 +245,27 @@ public class SparkUtil {
   public static boolean caseSensitive(SparkSession spark) {
     return Boolean.parseBoolean(spark.conf().get("spark.sql.caseSensitive"));
   }
+
+  public static List<String> executorLocations() {
+    BlockManager driverBlockManager = SparkEnv.get().blockManager();
+    List<BlockManagerId> executorBlockManagerIds = fetchPeers(driverBlockManager);
+    return executorBlockManagerIds.stream()
+        .map(SparkUtil::toExecutorLocation)
+        .sorted()
+        .collect(Collectors.toList());
+  }
+
+  private static List<BlockManagerId> fetchPeers(BlockManager blockManager) {
+    BlockManagerMaster master = blockManager.master();
+    BlockManagerId id = blockManager.blockManagerId();
+    return toJavaList(master.getPeers(id));
+  }
+
+  private static <T> List<T> toJavaList(Seq<T> seq) {
+    return JavaConverters.seqAsJavaListConverter(seq).asJava();
+  }
+
+  private static String toExecutorLocation(BlockManagerId id) {
+    return ExecutorCacheTaskLocation.apply(id.host(), id.executorId()).toString();
+  }
 }
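
For context, executorLocations() renders each peer block manager as a Spark ExecutorCacheTaskLocation string, which the scheduler later matches against registered executors. A small illustrative sketch (the host and executor ID values are made up):

    import org.apache.spark.scheduler.ExecutorCacheTaskLocation;

    class LocationStringDemo {
      public static void main(String[] args) {
        // Spark formats executor-pinned locations as "executor_<host>_<executorId>".
        String location = ExecutorCacheTaskLocation.apply("host1", "7").toString();
        System.out.println(location); // executor_host1_7
      }
    }
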
diff --git a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatch.java b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatch.java
index 4ed37a9f3d..fd6783f3e1 100644
--- a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatch.java
+++ b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatch.java
@@ -29,9 +29,8 @@ import org.apache.iceberg.Schema;
 import org.apache.iceberg.SchemaParser;
 import org.apache.iceberg.Table;
 import org.apache.iceberg.spark.SparkReadConf;
+import org.apache.iceberg.spark.SparkUtil;
 import org.apache.iceberg.types.Types;
-import org.apache.iceberg.util.Tasks;
-import org.apache.iceberg.util.ThreadPools;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.broadcast.Broadcast;
 import org.apache.spark.sql.connector.read.Batch;
@@ -49,6 +48,7 @@ class SparkBatch implements Batch {
   private final Schema expectedSchema;
   private final boolean caseSensitive;
   private final boolean localityEnabled;
+  private final boolean executorCacheLocalityEnabled;
   private final int scanHashCode;
 
   SparkBatch(
@@ -68,6 +68,7 @@ class SparkBatch implements Batch {
     this.expectedSchema = expectedSchema;
     this.caseSensitive = readConf.caseSensitive();
     this.localityEnabled = readConf.localityEnabled();
+    this.executorCacheLocalityEnabled = readConf.executorCacheLocalityEnabled();
     this.scanHashCode = scanHashCode;
   }
 
@@ -77,27 +78,39 @@ class SparkBatch implements Batch {
     Broadcast<Table> tableBroadcast =
         sparkContext.broadcast(SerializableTableWithSize.copyOf(table));
     String expectedSchemaString = SchemaParser.toJson(expectedSchema);
+    String[][] locations = computePreferredLocations();
 
     InputPartition[] partitions = new InputPartition[taskGroups.size()];
 
-    Tasks.range(partitions.length)
-        .stopOnFailure()
-        .executeWith(localityEnabled ? ThreadPools.getWorkerPool() : null)
-        .run(
-            index ->
-                partitions[index] =
-                    new SparkInputPartition(
-                        groupingKeyType,
-                        taskGroups.get(index),
-                        tableBroadcast,
-                        branch,
-                        expectedSchemaString,
-                        caseSensitive,
-                        localityEnabled));
+    for (int index = 0; index < taskGroups.size(); index++) {
+      partitions[index] =
+          new SparkInputPartition(
+              groupingKeyType,
+              taskGroups.get(index),
+              tableBroadcast,
+              branch,
+              expectedSchemaString,
+              caseSensitive,
+              locations != null ? locations[index] : SparkPlanningUtil.NO_LOCATION_PREFERENCE);
+    }
 
     return partitions;
   }
 
+  private String[][] computePreferredLocations() {
+    if (localityEnabled) {
+      return SparkPlanningUtil.fetchBlockLocations(table.io(), taskGroups);
+
+    } else if (executorCacheLocalityEnabled) {
+      List<String> executorLocations = SparkUtil.executorLocations();
+      if (!executorLocations.isEmpty()) {
+        return SparkPlanningUtil.assignExecutors(taskGroups, executorLocations);
+      }
+    }
+
+    return null;
+  }
+
   @Override
   public PartitionReaderFactory createReaderFactory() {
     if (useParquetBatchReads()) {
diff --git a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkInputPartition.java b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkInputPartition.java
index 0394b691e1..7826322be7 100644
--- a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkInputPartition.java
+++ b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkInputPartition.java
@@ -24,8 +24,6 @@ import org.apache.iceberg.ScanTaskGroup;
 import org.apache.iceberg.Schema;
 import org.apache.iceberg.SchemaParser;
 import org.apache.iceberg.Table;
-import org.apache.iceberg.hadoop.HadoopInputFile;
-import org.apache.iceberg.hadoop.Util;
 import org.apache.iceberg.types.Types;
 import org.apache.spark.broadcast.Broadcast;
 import org.apache.spark.sql.catalyst.InternalRow;
@@ -39,9 +37,9 @@ class SparkInputPartition implements InputPartition, HasPartitionKey, Serializab
   private final String branch;
   private final String expectedSchemaString;
   private final boolean caseSensitive;
+  private final transient String[] preferredLocations;
 
   private transient Schema expectedSchema = null;
-  private transient String[] preferredLocations = null;
 
   SparkInputPartition(
       Types.StructType groupingKeyType,
@@ -50,19 +48,14 @@ class SparkInputPartition implements InputPartition, HasPartitionKey, Serializab
       String branch,
       String expectedSchemaString,
       boolean caseSensitive,
-      boolean localityPreferred) {
+      String[] preferredLocations) {
     this.groupingKeyType = groupingKeyType;
     this.taskGroup = taskGroup;
     this.tableBroadcast = tableBroadcast;
     this.branch = branch;
     this.expectedSchemaString = expectedSchemaString;
     this.caseSensitive = caseSensitive;
-    if (localityPreferred) {
-      Table table = tableBroadcast.value();
-      this.preferredLocations = Util.blockLocations(table.io(), taskGroup);
-    } else {
-      this.preferredLocations = HadoopInputFile.NO_LOCATION_PREFERENCE;
-    }
+    this.preferredLocations = preferredLocations;
   }
 
   @Override
diff --git a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkMicroBatchStream.java b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkMicroBatchStream.java
index 3ffd9904bb..320d2e14ad 100644
--- a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkMicroBatchStream.java
+++ b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkMicroBatchStream.java
@@ -54,8 +54,6 @@ import org.apache.iceberg.util.Pair;
 import org.apache.iceberg.util.PropertyUtil;
 import org.apache.iceberg.util.SnapshotUtil;
 import org.apache.iceberg.util.TableScanUtil;
-import org.apache.iceberg.util.Tasks;
-import org.apache.iceberg.util.ThreadPools;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.broadcast.Broadcast;
 import org.apache.spark.sql.connector.read.InputPartition;
@@ -154,27 +152,29 @@ public class SparkMicroBatchStream implements MicroBatchStream, SupportsAdmissio
     List<CombinedScanTask> combinedScanTasks =
         Lists.newArrayList(
             TableScanUtil.planTasks(splitTasks, splitSize, splitLookback, splitOpenFileCost));
+    String[][] locations = computePreferredLocations(combinedScanTasks);
 
     InputPartition[] partitions = new InputPartition[combinedScanTasks.size()];
 
-    Tasks.range(partitions.length)
-        .stopOnFailure()
-        .executeWith(localityPreferred ? ThreadPools.getWorkerPool() : null)
-        .run(
-            index ->
-                partitions[index] =
-                    new SparkInputPartition(
-                        EMPTY_GROUPING_KEY_TYPE,
-                        combinedScanTasks.get(index),
-                        tableBroadcast,
-                        branch,
-                        expectedSchema,
-                        caseSensitive,
-                        localityPreferred));
+    for (int index = 0; index < combinedScanTasks.size(); index++) {
+      partitions[index] =
+          new SparkInputPartition(
+              EMPTY_GROUPING_KEY_TYPE,
+              combinedScanTasks.get(index),
+              tableBroadcast,
+              branch,
+              expectedSchema,
+              caseSensitive,
+              locations != null ? locations[index] : SparkPlanningUtil.NO_LOCATION_PREFERENCE);
+    }
 
     return partitions;
   }
 
+  private String[][] computePreferredLocations(List<CombinedScanTask> taskGroups) {
+    return localityPreferred ? SparkPlanningUtil.fetchBlockLocations(table.io(), taskGroups) : null;
+  }
+
   @Override
   public PartitionReaderFactory createReaderFactory() {
     return new SparkRowReaderFactory();
diff --git a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkPlanningUtil.java b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkPlanningUtil.java
new file mode 100644
index 0000000000..9cdec2c8f4
--- /dev/null
+++ b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkPlanningUtil.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.source;
+
+import java.util.List;
+import java.util.Map;
+import org.apache.iceberg.FileScanTask;
+import org.apache.iceberg.PartitionSpec;
+import org.apache.iceberg.ScanTask;
+import org.apache.iceberg.ScanTaskGroup;
+import org.apache.iceberg.StructLike;
+import org.apache.iceberg.hadoop.Util;
+import org.apache.iceberg.io.FileIO;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.relocated.com.google.common.collect.Maps;
+import org.apache.iceberg.types.JavaHash;
+import org.apache.iceberg.util.Tasks;
+import org.apache.iceberg.util.ThreadPools;
+
+class SparkPlanningUtil {
+
+  public static final String[] NO_LOCATION_PREFERENCE = new String[0];
+
+  private SparkPlanningUtil() {}
+
+  public static String[][] fetchBlockLocations(
+      FileIO io, List<? extends ScanTaskGroup<?>> taskGroups) {
+    String[][] locations = new String[taskGroups.size()][];
+
+    Tasks.range(taskGroups.size())
+        .stopOnFailure()
+        .executeWith(ThreadPools.getWorkerPool())
+        .run(index -> locations[index] = Util.blockLocations(io, taskGroups.get(index)));
+
+    return locations;
+  }
+
+  public static String[][] assignExecutors(
+      List<? extends ScanTaskGroup<?>> taskGroups, List<String> executorLocations) {
+    Map<Integer, JavaHash<StructLike>> partitionHashes = Maps.newHashMap();
+    String[][] locations = new String[taskGroups.size()][];
+
+    for (int index = 0; index < taskGroups.size(); index++) {
+      locations[index] = assign(taskGroups.get(index), executorLocations, partitionHashes);
+    }
+
+    return locations;
+  }
+
+  private static String[] assign(
+      ScanTaskGroup<?> taskGroup,
+      List<String> executorLocations,
+      Map<Integer, JavaHash<StructLike>> partitionHashes) {
+    List<String> locations = Lists.newArrayList();
+
+    for (ScanTask task : taskGroup.tasks()) {
+      if (task.isFileScanTask()) {
+        FileScanTask fileTask = task.asFileScanTask();
+        PartitionSpec spec = fileTask.spec();
+        if (spec.isPartitioned() && !fileTask.deletes().isEmpty()) {
+          JavaHash<StructLike> partitionHash =
+              partitionHashes.computeIfAbsent(spec.specId(), key -> partitionHash(spec));
+          int partitionHashCode = partitionHash.hash(fileTask.partition());
+          int index = Math.floorMod(partitionHashCode, executorLocations.size());
+          String executorLocation = executorLocations.get(index);
+          locations.add(executorLocation);
+        }
+      }
+    }
+
+    return locations.toArray(NO_LOCATION_PREFERENCE);
+  }
+
+  private static JavaHash<StructLike> partitionHash(PartitionSpec spec) {
+    return JavaHash.forType(spec.partitionType());
+  }
+}
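
The assignment above is deterministic: a given partition tuple always hashes to the same slot, so every task that carries delete files for that partition is routed to the same executor and can reuse its cached, decoded deletes. A self-contained sketch of the same Math.floorMod idea (plain strings stand in for partition tuples; this is not the Iceberg JavaHash API):

    import java.util.List;

    class FloorModAssignmentDemo {
      public static void main(String[] args) {
        List<String> executors =
            List.of("executor_host1_1", "executor_host1_2", "executor_host2_1");
        // floorMod keeps the slot non-negative even for negative hash codes,
        // which is why it is used instead of the % operator.
        for (String partition : List.of("dep=hr", "dep=hardware", "dep=hr")) {
          int slot = Math.floorMod(partition.hashCode(), executors.size());
          System.out.println(partition + " -> " + executors.get(slot));
        }
      }
    }

Both occurrences of "dep=hr" print the same executor, which is the cache-affinity property the planner relies on.
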
diff --git a/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkPlanningUtil.java b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkPlanningUtil.java
new file mode 100644
index 0000000000..65c6790e5b
--- /dev/null
+++ b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkPlanningUtil.java
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.source;
+
+import static org.apache.iceberg.types.Types.NestedField.required;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.Mockito.when;
+
+import java.util.List;
+import org.apache.iceberg.BaseScanTaskGroup;
+import org.apache.iceberg.DataFile;
+import org.apache.iceberg.DataTask;
+import org.apache.iceberg.DeleteFile;
+import org.apache.iceberg.MockFileScanTask;
+import org.apache.iceberg.PartitionSpec;
+import org.apache.iceberg.ScanTask;
+import org.apache.iceberg.ScanTaskGroup;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.StructLike;
+import org.apache.iceberg.TestHelpers.Row;
+import org.apache.iceberg.io.CloseableIterable;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.iceberg.spark.TestBaseWithCatalog;
+import org.apache.iceberg.types.Types;
+import org.junit.jupiter.api.TestTemplate;
+import org.mockito.Mockito;
+
+public class TestSparkPlanningUtil extends TestBaseWithCatalog {
+
+  private static final Schema SCHEMA =
+      new Schema(
+          required(1, "id", Types.IntegerType.get()),
+          required(2, "data", Types.StringType.get()),
+          required(3, "category", Types.StringType.get()));
+  private static final PartitionSpec SPEC_1 =
+      PartitionSpec.builderFor(SCHEMA).withSpecId(1).bucket("id", 16).identity("data").build();
+  private static final PartitionSpec SPEC_2 =
+      PartitionSpec.builderFor(SCHEMA).withSpecId(2).identity("data").build();
+  private static final List<String> EXECUTOR_LOCATIONS =
+      ImmutableList.of("host1_exec1", "host1_exec2", "host1_exec3", 
"host2_exec1", "host2_exec2");
+
+  @TestTemplate
+  public void testFileScanTaskWithoutDeletes() {
+    List<ScanTask> tasks =
+        ImmutableList.of(
+            new MockFileScanTask(mockDataFile(Row.of(1, "a")), SCHEMA, SPEC_1),
+            new MockFileScanTask(mockDataFile(Row.of(2, "b")), SCHEMA, SPEC_1),
+            new MockFileScanTask(mockDataFile(Row.of(3, "c")), SCHEMA, SPEC_1));
+    ScanTaskGroup<ScanTask> taskGroup = new BaseScanTaskGroup<>(tasks);
+    List<ScanTaskGroup<ScanTask>> taskGroups = ImmutableList.of(taskGroup);
+
+    String[][] locations = SparkPlanningUtil.assignExecutors(taskGroups, EXECUTOR_LOCATIONS);
+
+    // should not assign executors if there are no deletes
+    assertThat(locations.length).isEqualTo(1);
+    assertThat(locations[0]).isEmpty();
+  }
+
+  @TestTemplate
+  public void testFileScanTaskWithDeletes() {
+    StructLike partition1 = Row.of("k2", null);
+    StructLike partition2 = Row.of("k1");
+    List<ScanTask> tasks =
+        ImmutableList.of(
+            new MockFileScanTask(
+                mockDataFile(partition1), mockDeleteFiles(1, partition1), SCHEMA, SPEC_1),
+            new MockFileScanTask(
+                mockDataFile(partition2), mockDeleteFiles(3, partition2), SCHEMA, SPEC_2),
+            new MockFileScanTask(
+                mockDataFile(partition1), mockDeleteFiles(2, partition1), SCHEMA, SPEC_1));
+    ScanTaskGroup<ScanTask> taskGroup = new BaseScanTaskGroup<>(tasks);
+    List<ScanTaskGroup<ScanTask>> taskGroups = ImmutableList.of(taskGroup);
+
+    String[][] locations = SparkPlanningUtil.assignExecutors(taskGroups, EXECUTOR_LOCATIONS);
+
+    // should assign executors and handle partition tuples of different sizes
+    assertThat(locations.length).isEqualTo(1);
+    assertThat(locations[0].length).isGreaterThanOrEqualTo(1);
+  }
+
+  @TestTemplate
+  public void testFileScanTaskWithUnpartitionedDeletes() {
+    List<ScanTask> tasks1 =
+        ImmutableList.of(
+            new MockFileScanTask(
+                mockDataFile(Row.of()),
+                mockDeleteFiles(2, Row.of()),
+                SCHEMA,
+                PartitionSpec.unpartitioned()),
+            new MockFileScanTask(
+                mockDataFile(Row.of()),
+                mockDeleteFiles(2, Row.of()),
+                SCHEMA,
+                PartitionSpec.unpartitioned()),
+            new MockFileScanTask(
+                mockDataFile(Row.of()),
+                mockDeleteFiles(2, Row.of()),
+                SCHEMA,
+                PartitionSpec.unpartitioned()));
+    ScanTaskGroup<ScanTask> taskGroup1 = new BaseScanTaskGroup<>(tasks1);
+    List<ScanTask> tasks2 =
+        ImmutableList.of(
+            new MockFileScanTask(
+                mockDataFile(null),
+                mockDeleteFiles(2, null),
+                SCHEMA,
+                PartitionSpec.unpartitioned()),
+            new MockFileScanTask(
+                mockDataFile(null),
+                mockDeleteFiles(2, null),
+                SCHEMA,
+                PartitionSpec.unpartitioned()),
+            new MockFileScanTask(
+                mockDataFile(null),
+                mockDeleteFiles(2, null),
+                SCHEMA,
+                PartitionSpec.unpartitioned()));
+    ScanTaskGroup<ScanTask> taskGroup2 = new BaseScanTaskGroup<>(tasks2);
+    List<ScanTaskGroup<ScanTask>> taskGroups = ImmutableList.of(taskGroup1, taskGroup2);
+
+    String[][] locations = SparkPlanningUtil.assignExecutors(taskGroups, EXECUTOR_LOCATIONS);
+
+    // should not assign executors if the table is unpartitioned
+    assertThat(locations.length).isEqualTo(2);
+    assertThat(locations[0]).isEmpty();
+    assertThat(locations[1]).isEmpty();
+  }
+
+  @TestTemplate
+  public void testDataTasks() {
+    List<ScanTask> tasks =
+        ImmutableList.of(
+            new MockDataTask(mockDataFile(Row.of(1, "a"))),
+            new MockDataTask(mockDataFile(Row.of(2, "b"))),
+            new MockDataTask(mockDataFile(Row.of(3, "c"))));
+    ScanTaskGroup<ScanTask> taskGroup = new BaseScanTaskGroup<>(tasks);
+    List<ScanTaskGroup<ScanTask>> taskGroups = ImmutableList.of(taskGroup);
+
+    String[][] locations = SparkPlanningUtil.assignExecutors(taskGroups, EXECUTOR_LOCATIONS);
+
+    // should not assign executors for data tasks
+    assertThat(locations.length).isEqualTo(1);
+    assertThat(locations[0]).isEmpty();
+  }
+
+  @TestTemplate
+  public void testUnknownTasks() {
+    List<ScanTask> tasks = ImmutableList.of(new UnknownScanTask(), new UnknownScanTask());
+    ScanTaskGroup<ScanTask> taskGroup = new BaseScanTaskGroup<>(tasks);
+    List<ScanTaskGroup<ScanTask>> taskGroups = ImmutableList.of(taskGroup);
+
+    String[][] locations = SparkPlanningUtil.assignExecutors(taskGroups, EXECUTOR_LOCATIONS);
+
+    // should not assign executors for unknown tasks
+    assertThat(locations.length).isEqualTo(1);
+    assertThat(locations[0]).isEmpty();
+  }
+
+  private static DataFile mockDataFile(StructLike partition) {
+    DataFile file = Mockito.mock(DataFile.class);
+    when(file.partition()).thenReturn(partition);
+    return file;
+  }
+
+  private static DeleteFile[] mockDeleteFiles(int count, StructLike partition) {
+    DeleteFile[] files = new DeleteFile[count];
+    for (int index = 0; index < count; index++) {
+      files[index] = mockDeleteFile(partition);
+    }
+    return files;
+  }
+
+  private static DeleteFile mockDeleteFile(StructLike partition) {
+    DeleteFile file = Mockito.mock(DeleteFile.class);
+    when(file.partition()).thenReturn(partition);
+    return file;
+  }
+
+  private static class MockDataTask extends MockFileScanTask implements DataTask {
+
+    MockDataTask(DataFile file) {
+      super(file);
+    }
+
+    @Override
+    public PartitionSpec spec() {
+      return PartitionSpec.unpartitioned();
+    }
+
+    @Override
+    public CloseableIterable<StructLike> rows() {
+      throw new UnsupportedOperationException();
+    }
+  }
+
+  private static class UnknownScanTask implements ScanTask {}
+}
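
End to end, the feature is exercised the way the new TestMergeOnReadDelete case does it: enable the session flag, then run merge-on-read DELETEs so that repeated reads of a partition's delete files land on the same executor. A condensed sketch outside the test harness (the table name db.employees and the session setup are assumptions, not part of this commit):

    import org.apache.iceberg.spark.SparkSQLProperties;
    import org.apache.spark.sql.SparkSession;

    class ExecutorCacheLocalityExample {
      public static void main(String[] args) {
        SparkSession spark = SparkSession.builder().getOrCreate();
        spark.conf().set(SparkSQLProperties.EXECUTOR_CACHE_LOCALITY_ENABLED, "true");
        // Merge-on-read deletes against a partitioned Iceberg table.
        spark.sql("DELETE FROM db.employees WHERE id = 1");
        spark.sql("DELETE FROM db.employees WHERE id = 3");
        spark.sql("SELECT * FROM db.employees ORDER BY id").show();
      }
    }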
