This is an automated email from the ASF dual-hosted git repository.
aokolnychyi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg.git
The following commit(s) were added to refs/heads/main by this push:
new c745ac3b6a Spark 3.5: Support executor cache locality (#9563)
c745ac3b6a is described below
commit c745ac3b6a3b2f24ae5170aa18f9eeec7cf3cbc7
Author: Anton Okolnychyi <[email protected]>
AuthorDate: Mon Feb 5 10:33:51 2024 -0800
Spark 3.5: Support executor cache locality (#9563)
---
.../java/org/apache/iceberg/MockFileScanTask.java | 11 ++
.../spark/extensions/TestMergeOnReadDelete.java | 25 +++
.../org/apache/iceberg/spark/SparkReadConf.java | 20 ++
.../apache/iceberg/spark/SparkSQLProperties.java | 4 +
.../java/org/apache/iceberg/spark/SparkUtil.java | 30 +++
.../apache/iceberg/spark/source/SparkBatch.java | 45 +++--
.../iceberg/spark/source/SparkInputPartition.java | 13 +-
.../spark/source/SparkMicroBatchStream.java | 32 ++--
.../iceberg/spark/source/SparkPlanningUtil.java | 93 +++++++++
.../spark/source/TestSparkPlanningUtil.java | 213 +++++++++++++++++++++
10 files changed, 444 insertions(+), 42 deletions(-)
diff --git a/core/src/test/java/org/apache/iceberg/MockFileScanTask.java b/core/src/test/java/org/apache/iceberg/MockFileScanTask.java
index 58275ad3f0..565433c82c 100644
--- a/core/src/test/java/org/apache/iceberg/MockFileScanTask.java
+++ b/core/src/test/java/org/apache/iceberg/MockFileScanTask.java
@@ -44,6 +44,17 @@ public class MockFileScanTask extends BaseFileScanTask {
this.length = file.fileSizeInBytes();
}
+ public MockFileScanTask(DataFile file, Schema schema, PartitionSpec spec) { // mock task built from a data file plus an explicit schema and partition spec
+ super(file, null, SchemaParser.toJson(schema),
PartitionSpecParser.toJson(spec), null); // schema and spec go to the parent as JSON; the second argument (where the sibling constructor passes deleteFiles) is null
+ this.length = file.fileSizeInBytes(); // length is taken from the file's reported size
+ }
+
+ public MockFileScanTask(
+ DataFile file, DeleteFile[] deleteFiles, Schema schema, PartitionSpec
spec) { // same as the three-argument constructor, but also attaches delete files to the task
+ super(file, deleteFiles, SchemaParser.toJson(schema),
PartitionSpecParser.toJson(spec), null); // schema and spec go to the parent as JSON
+ this.length = file.fileSizeInBytes(); // length is taken from the file's reported size
+ }
+
public static MockFileScanTask mockTask(long length, int sortOrderId) {
DataFile mockFile = Mockito.mock(DataFile.class);
Mockito.when(mockFile.fileSizeInBytes()).thenReturn(length);
diff --git a/spark/v3.5/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadDelete.java b/spark/v3.5/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadDelete.java
index 01f24c4dfe..91600d4df0 100644
--- a/spark/v3.5/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadDelete.java
+++ b/spark/v3.5/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadDelete.java
@@ -37,6 +37,7 @@ import
org.apache.iceberg.exceptions.CommitStateUnknownException;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.spark.SparkSQLProperties;
import org.apache.iceberg.spark.source.SparkTable;
import org.apache.iceberg.spark.source.TestSparkCatalog;
import org.apache.iceberg.util.SnapshotUtil;
@@ -85,6 +86,30 @@ public class TestMergeOnReadDelete extends TestDelete {
TestSparkCatalog.clearTables();
}
+ @Test
+ public void testDeleteWithExecutorCacheLocality() throws
NoSuchTableException { // runs two DELETEs with the executor cache locality property set and checks the remaining rows
+ createAndInitPartitionedTable();
+
+ append(tableName, new Employee(1, "hr"), new Employee(2, "hr"));
+ append(tableName, new Employee(3, "hr"), new Employee(4, "hr"));
+ append(tableName, new Employee(1, "hardware"), new Employee(2,
"hardware"));
+ append(tableName, new Employee(3, "hardware"), new Employee(4,
"hardware")); // four appends in total: two for "hr" and two for "hardware"
+
+ createBranchIfNeeded();
+
+ withSQLConf(
+ ImmutableMap.of(SparkSQLProperties.EXECUTOR_CACHE_LOCALITY_ENABLED,
"true"), // NOTE(review): presumably applies only while the lambda below runs - confirm against withSQLConf
+ () -> {
+ sql("DELETE FROM %s WHERE id = 1", commitTarget());
+ sql("DELETE FROM %s WHERE id = 3", commitTarget()); // second consecutive DELETE against the same target
+
+ assertEquals(
+ "Should have expected rows",
+ ImmutableList.of(row(2, "hardware"), row(2, "hr"), row(4,
"hardware"), row(4, "hr")), // ids 1 and 3 are gone from both departments
+ sql("SELECT * FROM %s ORDER BY id ASC, dep ASC",
selectTarget()));
+ });
+ }
+
@Test
public void testDeleteFileGranularity() throws NoSuchTableException {
checkDeleteFileGranularity(DeleteGranularity.FILE);
diff --git a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkReadConf.java b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkReadConf.java
index 984e2bce1e..2990d981d0 100644
--- a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkReadConf.java
+++ b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkReadConf.java
@@ -331,4 +331,24 @@ public class SparkReadConf {
SparkConf sparkConf = spark.sparkContext().conf();
return sparkConf.getSizeAsBytes(DRIVER_MAX_RESULT_SIZE,
DRIVER_MAX_RESULT_SIZE_DEFAULT);
}
+
+ public boolean executorCacheLocalityEnabled() { // true only when both the executor cache and its locality option are enabled
+ return executorCacheEnabled() && executorCacheLocalityEnabledInternal(); // short-circuits: the locality option is not parsed when the cache is disabled
+ }
+
+ private boolean executorCacheEnabled() { // parses the EXECUTOR_CACHE_ENABLED session conf as a boolean
+ return confParser
+ .booleanConf()
+ .sessionConf(SparkSQLProperties.EXECUTOR_CACHE_ENABLED)
+ .defaultValue(SparkSQLProperties.EXECUTOR_CACHE_ENABLED_DEFAULT) // used when the session conf is not set
+ .parse();
+ }
+
+ private boolean executorCacheLocalityEnabledInternal() { // parses the EXECUTOR_CACHE_LOCALITY_ENABLED session conf as a boolean, without checking whether the cache itself is on
+ return confParser
+ .booleanConf()
+ .sessionConf(SparkSQLProperties.EXECUTOR_CACHE_LOCALITY_ENABLED)
+
.defaultValue(SparkSQLProperties.EXECUTOR_CACHE_LOCALITY_ENABLED_DEFAULT)
+ .parse();
+ }
}
diff --git a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java
index 4a66520231..ea8f6fe071 100644
--- a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java
+++ b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java
@@ -86,4 +86,8 @@ public class SparkSQLProperties {
public static final String EXECUTOR_CACHE_MAX_TOTAL_SIZE =
"spark.sql.iceberg.executor-cache.max-total-size";
public static final long EXECUTOR_CACHE_MAX_TOTAL_SIZE_DEFAULT = 128 * 1024
* 1024; // 128 MB
+
+ public static final String EXECUTOR_CACHE_LOCALITY_ENABLED =
+ "spark.sql.iceberg.executor-cache.locality.enabled";
+ public static final boolean EXECUTOR_CACHE_LOCALITY_ENABLED_DEFAULT = false;
}
diff --git a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkUtil.java b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkUtil.java
index 2357ca0441..de06cceb26 100644
--- a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkUtil.java
+++ b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SparkUtil.java
@@ -34,6 +34,8 @@ import
org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.transforms.Transform;
import org.apache.iceberg.transforms.UnknownTransform;
import org.apache.iceberg.util.Pair;
+import org.apache.spark.SparkEnv;
+import org.apache.spark.scheduler.ExecutorCacheTaskLocation;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.expressions.BoundReference;
import org.apache.spark.sql.catalyst.expressions.EqualTo;
@@ -43,7 +45,12 @@ import
org.apache.spark.sql.connector.expressions.NamedReference;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;
+import org.apache.spark.storage.BlockManager;
+import org.apache.spark.storage.BlockManagerId;
+import org.apache.spark.storage.BlockManagerMaster;
import org.joda.time.DateTime;
+import scala.collection.JavaConverters;
+import scala.collection.Seq;
public class SparkUtil {
private static final String SPARK_CATALOG_CONF_PREFIX = "spark.sql.catalog";
@@ -238,4 +245,27 @@ public class SparkUtil {
public static boolean caseSensitive(SparkSession spark) {
return Boolean.parseBoolean(spark.conf().get("spark.sql.caseSensitive"));
}
+
+ public static List<String> executorLocations() { // returns one executor cache task location string per peer block manager
+ BlockManager driverBlockManager = SparkEnv.get().blockManager(); // block manager of the current SparkEnv; NOTE(review): assumes this is called on the driver - confirm
+ List<BlockManagerId> executorBlockManagerIds =
fetchPeers(driverBlockManager);
+ return executorBlockManagerIds.stream()
+ .map(SparkUtil::toExecutorLocation)
+ .sorted() // natural string order; presumably keeps index-based executor assignment stable across calls
+ .collect(Collectors.toList());
+ }
+
+ private static List<BlockManagerId> fetchPeers(BlockManager blockManager) { // asks the block manager master for the peers of the given block manager
+ BlockManagerMaster master = blockManager.master();
+ BlockManagerId id = blockManager.blockManagerId();
+ return toJavaList(master.getPeers(id)); // getPeers returns a Scala Seq, converted here to a Java list
+ }
+
+ private static <T> List<T> toJavaList(Seq<T> seq) { // converts a Scala Seq into a java.util.List via JavaConverters
+ return JavaConverters.seqAsJavaListConverter(seq).asJava();
+ }
+
+ private static String toExecutorLocation(BlockManagerId id) { // renders the block manager's host and executor ID as an ExecutorCacheTaskLocation string
+ return ExecutorCacheTaskLocation.apply(id.host(),
id.executorId()).toString();
+ }
}
diff --git a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatch.java b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatch.java
index 4ed37a9f3d..fd6783f3e1 100644
--- a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatch.java
+++ b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatch.java
@@ -29,9 +29,8 @@ import org.apache.iceberg.Schema;
import org.apache.iceberg.SchemaParser;
import org.apache.iceberg.Table;
import org.apache.iceberg.spark.SparkReadConf;
+import org.apache.iceberg.spark.SparkUtil;
import org.apache.iceberg.types.Types;
-import org.apache.iceberg.util.Tasks;
-import org.apache.iceberg.util.ThreadPools;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.sql.connector.read.Batch;
@@ -49,6 +48,7 @@ class SparkBatch implements Batch {
private final Schema expectedSchema;
private final boolean caseSensitive;
private final boolean localityEnabled;
+ private final boolean executorCacheLocalityEnabled;
private final int scanHashCode;
SparkBatch(
@@ -68,6 +68,7 @@ class SparkBatch implements Batch {
this.expectedSchema = expectedSchema;
this.caseSensitive = readConf.caseSensitive();
this.localityEnabled = readConf.localityEnabled();
+ this.executorCacheLocalityEnabled =
readConf.executorCacheLocalityEnabled();
this.scanHashCode = scanHashCode;
}
@@ -77,27 +78,39 @@ class SparkBatch implements Batch {
Broadcast<Table> tableBroadcast =
sparkContext.broadcast(SerializableTableWithSize.copyOf(table));
String expectedSchemaString = SchemaParser.toJson(expectedSchema);
+ String[][] locations = computePreferredLocations();
InputPartition[] partitions = new InputPartition[taskGroups.size()];
- Tasks.range(partitions.length)
- .stopOnFailure()
- .executeWith(localityEnabled ? ThreadPools.getWorkerPool() : null)
- .run(
- index ->
- partitions[index] =
- new SparkInputPartition(
- groupingKeyType,
- taskGroups.get(index),
- tableBroadcast,
- branch,
- expectedSchemaString,
- caseSensitive,
- localityEnabled));
+ for (int index = 0; index < taskGroups.size(); index++) {
+ partitions[index] =
+ new SparkInputPartition(
+ groupingKeyType,
+ taskGroups.get(index),
+ tableBroadcast,
+ branch,
+ expectedSchemaString,
+ caseSensitive,
+ locations != null ? locations[index] :
SparkPlanningUtil.NO_LOCATION_PREFERENCE);
+ }
return partitions;
}
+ private String[][] computePreferredLocations() { // returns preferred locations indexed by task group, or null when none are computed
+ if (localityEnabled) { // block locality is checked first and wins over executor cache locality
+ return SparkPlanningUtil.fetchBlockLocations(table.io(), taskGroups);
+
+ } else if (executorCacheLocalityEnabled) {
+ List<String> executorLocations = SparkUtil.executorLocations();
+ if (!executorLocations.isEmpty()) { // with no known executors there is nothing to assign, so fall through to null
+ return SparkPlanningUtil.assignExecutors(taskGroups,
executorLocations);
+ }
+ }
+
+ return null; // the caller substitutes SparkPlanningUtil.NO_LOCATION_PREFERENCE for each partition
+ }
+
@Override
public PartitionReaderFactory createReaderFactory() {
if (useParquetBatchReads()) {
diff --git a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkInputPartition.java b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkInputPartition.java
index 0394b691e1..7826322be7 100644
--- a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkInputPartition.java
+++ b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkInputPartition.java
@@ -24,8 +24,6 @@ import org.apache.iceberg.ScanTaskGroup;
import org.apache.iceberg.Schema;
import org.apache.iceberg.SchemaParser;
import org.apache.iceberg.Table;
-import org.apache.iceberg.hadoop.HadoopInputFile;
-import org.apache.iceberg.hadoop.Util;
import org.apache.iceberg.types.Types;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.sql.catalyst.InternalRow;
@@ -39,9 +37,9 @@ class SparkInputPartition implements InputPartition,
HasPartitionKey, Serializab
private final String branch;
private final String expectedSchemaString;
private final boolean caseSensitive;
+ private final transient String[] preferredLocations;
private transient Schema expectedSchema = null;
- private transient String[] preferredLocations = null;
SparkInputPartition(
Types.StructType groupingKeyType,
@@ -50,19 +48,14 @@ class SparkInputPartition implements InputPartition,
HasPartitionKey, Serializab
String branch,
String expectedSchemaString,
boolean caseSensitive,
- boolean localityPreferred) {
+ String[] preferredLocations) {
this.groupingKeyType = groupingKeyType;
this.taskGroup = taskGroup;
this.tableBroadcast = tableBroadcast;
this.branch = branch;
this.expectedSchemaString = expectedSchemaString;
this.caseSensitive = caseSensitive;
- if (localityPreferred) {
- Table table = tableBroadcast.value();
- this.preferredLocations = Util.blockLocations(table.io(), taskGroup);
- } else {
- this.preferredLocations = HadoopInputFile.NO_LOCATION_PREFERENCE;
- }
+ this.preferredLocations = preferredLocations;
}
@Override
diff --git a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkMicroBatchStream.java b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkMicroBatchStream.java
index 3ffd9904bb..320d2e14ad 100644
--- a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkMicroBatchStream.java
+++ b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkMicroBatchStream.java
@@ -54,8 +54,6 @@ import org.apache.iceberg.util.Pair;
import org.apache.iceberg.util.PropertyUtil;
import org.apache.iceberg.util.SnapshotUtil;
import org.apache.iceberg.util.TableScanUtil;
-import org.apache.iceberg.util.Tasks;
-import org.apache.iceberg.util.ThreadPools;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.sql.connector.read.InputPartition;
@@ -154,27 +152,29 @@ public class SparkMicroBatchStream implements
MicroBatchStream, SupportsAdmissio
List<CombinedScanTask> combinedScanTasks =
Lists.newArrayList(
TableScanUtil.planTasks(splitTasks, splitSize, splitLookback,
splitOpenFileCost));
+ String[][] locations = computePreferredLocations(combinedScanTasks);
InputPartition[] partitions = new InputPartition[combinedScanTasks.size()];
- Tasks.range(partitions.length)
- .stopOnFailure()
- .executeWith(localityPreferred ? ThreadPools.getWorkerPool() : null)
- .run(
- index ->
- partitions[index] =
- new SparkInputPartition(
- EMPTY_GROUPING_KEY_TYPE,
- combinedScanTasks.get(index),
- tableBroadcast,
- branch,
- expectedSchema,
- caseSensitive,
- localityPreferred));
+ for (int index = 0; index < combinedScanTasks.size(); index++) {
+ partitions[index] =
+ new SparkInputPartition(
+ EMPTY_GROUPING_KEY_TYPE,
+ combinedScanTasks.get(index),
+ tableBroadcast,
+ branch,
+ expectedSchema,
+ caseSensitive,
+ locations != null ? locations[index] :
SparkPlanningUtil.NO_LOCATION_PREFERENCE);
+ }
return partitions;
}
+ private String[][] computePreferredLocations(List<CombinedScanTask>
taskGroups) { // block locations per task group when locality is preferred, otherwise null
+ return localityPreferred ?
SparkPlanningUtil.fetchBlockLocations(table.io(), taskGroups) : null; // on null the caller uses SparkPlanningUtil.NO_LOCATION_PREFERENCE for every partition
+ }
+
@Override
public PartitionReaderFactory createReaderFactory() {
return new SparkRowReaderFactory();
diff --git a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkPlanningUtil.java b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkPlanningUtil.java
new file mode 100644
index 0000000000..9cdec2c8f4
--- /dev/null
+++ b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkPlanningUtil.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.source;
+
+import java.util.List;
+import java.util.Map;
+import org.apache.iceberg.FileScanTask;
+import org.apache.iceberg.PartitionSpec;
+import org.apache.iceberg.ScanTask;
+import org.apache.iceberg.ScanTaskGroup;
+import org.apache.iceberg.StructLike;
+import org.apache.iceberg.hadoop.Util;
+import org.apache.iceberg.io.FileIO;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.relocated.com.google.common.collect.Maps;
+import org.apache.iceberg.types.JavaHash;
+import org.apache.iceberg.util.Tasks;
+import org.apache.iceberg.util.ThreadPools;
+
+class SparkPlanningUtil { // static helpers that compute preferred locations for scan task groups
+
+ public static final String[] NO_LOCATION_PREFERENCE = new String[0]; // empty array, meaning no preferred locations
+
+ private SparkPlanningUtil() {} // utility class, not instantiable
+
+ public static String[][] fetchBlockLocations(
+ FileIO io, List<? extends ScanTaskGroup<?>> taskGroups) { // returns Util.blockLocations for each task group, at the group's index
+ String[][] locations = new String[taskGroups.size()][];
+
+ Tasks.range(taskGroups.size())
+ .stopOnFailure()
+ .executeWith(ThreadPools.getWorkerPool()) // lookups are submitted to the shared worker pool
+ .run(index -> locations[index] = Util.blockLocations(io,
taskGroups.get(index))); // each index writes only its own slot of the result array
+
+ return locations;
+ }
+
+ public static String[][] assignExecutors(
+ List<? extends ScanTaskGroup<?>> taskGroups, List<String>
executorLocations) { // returns executor locations for each task group, at the group's index; an entry may be empty
+ Map<Integer, JavaHash<StructLike>> partitionHashes = Maps.newHashMap(); // caches one partition hash function per spec ID across all task groups
+ String[][] locations = new String[taskGroups.size()][];
+
+ for (int index = 0; index < taskGroups.size(); index++) {
+ locations[index] = assign(taskGroups.get(index), executorLocations,
partitionHashes);
+ }
+
+ return locations;
+ }
+
+ private static String[] assign(
+ ScanTaskGroup<?> taskGroup,
+ List<String> executorLocations,
+ Map<Integer, JavaHash<StructLike>> partitionHashes) { // collects one executor location per qualifying task in the group
+ List<String> locations = Lists.newArrayList();
+
+ for (ScanTask task : taskGroup.tasks()) {
+ if (task.isFileScanTask()) { // any other kind of task contributes no location
+ FileScanTask fileTask = task.asFileScanTask();
+ PartitionSpec spec = fileTask.spec();
+ if (spec.isPartitioned() && !fileTask.deletes().isEmpty()) { // only tasks with deletes under a partitioned spec get an executor
+ JavaHash<StructLike> partitionHash =
+ partitionHashes.computeIfAbsent(spec.specId(), key ->
partitionHash(spec));
+ int partitionHashCode = partitionHash.hash(fileTask.partition());
+ int index = Math.floorMod(partitionHashCode,
executorLocations.size()); // floorMod keeps the index in [0, size) even for negative hash codes
+ String executorLocation = executorLocations.get(index); // the same hash code and executor list always select the same executor
+ locations.add(executorLocation); // no de-duplication: an executor is added once per qualifying task
+ }
+ }
+ }
+
+ return locations.toArray(NO_LOCATION_PREFERENCE); // yields an empty array when no task qualified
+ }
+
+ private static JavaHash<StructLike> partitionHash(PartitionSpec spec) { // hash function for the spec's partition struct type
+ return JavaHash.forType(spec.partitionType());
+ }
+}
diff --git a/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkPlanningUtil.java b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkPlanningUtil.java
new file mode 100644
index 0000000000..65c6790e5b
--- /dev/null
+++ b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkPlanningUtil.java
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.source;
+
+import static org.apache.iceberg.types.Types.NestedField.required;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.Mockito.when;
+
+import java.util.List;
+import org.apache.iceberg.BaseScanTaskGroup;
+import org.apache.iceberg.DataFile;
+import org.apache.iceberg.DataTask;
+import org.apache.iceberg.DeleteFile;
+import org.apache.iceberg.MockFileScanTask;
+import org.apache.iceberg.PartitionSpec;
+import org.apache.iceberg.ScanTask;
+import org.apache.iceberg.ScanTaskGroup;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.StructLike;
+import org.apache.iceberg.TestHelpers.Row;
+import org.apache.iceberg.io.CloseableIterable;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.iceberg.spark.TestBaseWithCatalog;
+import org.apache.iceberg.types.Types;
+import org.junit.jupiter.api.TestTemplate;
+import org.mockito.Mockito;
+
+public class TestSparkPlanningUtil extends TestBaseWithCatalog { // tests for SparkPlanningUtil.assignExecutors over different kinds of scan tasks
+
+ private static final Schema SCHEMA =
+ new Schema(
+ required(1, "id", Types.IntegerType.get()),
+ required(2, "data", Types.StringType.get()),
+ required(3, "category", Types.StringType.get()));
+ private static final PartitionSpec SPEC_1 =
+ PartitionSpec.builderFor(SCHEMA).withSpecId(1).bucket("id",
16).identity("data").build(); // spec ID 1 with two partition fields
+ private static final PartitionSpec SPEC_2 =
+ PartitionSpec.builderFor(SCHEMA).withSpecId(2).identity("data").build(); // spec ID 2 with a single partition field
+ private static final List<String> EXECUTOR_LOCATIONS =
+ ImmutableList.of("host1_exec1", "host1_exec2", "host1_exec3",
"host2_exec1", "host2_exec2"); // five candidate executor location strings
+
+ @TestTemplate
+ public void testFileScanTaskWithoutDeletes() {
+ List<ScanTask> tasks =
+ ImmutableList.of(
+ new MockFileScanTask(mockDataFile(Row.of(1, "a")), SCHEMA, SPEC_1),
+ new MockFileScanTask(mockDataFile(Row.of(2, "b")), SCHEMA, SPEC_1),
+ new MockFileScanTask(mockDataFile(Row.of(3, "c")), SCHEMA,
SPEC_1)); // built with the constructor that takes no delete files
+ ScanTaskGroup<ScanTask> taskGroup = new BaseScanTaskGroup<>(tasks);
+ List<ScanTaskGroup<ScanTask>> taskGroups = ImmutableList.of(taskGroup);
+
+ String[][] locations = SparkPlanningUtil.assignExecutors(taskGroups,
EXECUTOR_LOCATIONS);
+
+ // should not assign executors if there are no deletes
+ assertThat(locations.length).isEqualTo(1);
+ assertThat(locations[0]).isEmpty();
+ }
+
+ @TestTemplate
+ public void testFileScanTaskWithDeletes() {
+ StructLike partition1 = Row.of("k2", null); // two-value tuple, used below with SPEC_1
+ StructLike partition2 = Row.of("k1"); // one-value tuple, used below with SPEC_2
+ List<ScanTask> tasks =
+ ImmutableList.of(
+ new MockFileScanTask(
+ mockDataFile(partition1), mockDeleteFiles(1, partition1),
SCHEMA, SPEC_1),
+ new MockFileScanTask(
+ mockDataFile(partition2), mockDeleteFiles(3, partition2),
SCHEMA, SPEC_2),
+ new MockFileScanTask(
+ mockDataFile(partition1), mockDeleteFiles(2, partition1),
SCHEMA, SPEC_1));
+ ScanTaskGroup<ScanTask> taskGroup = new BaseScanTaskGroup<>(tasks);
+ List<ScanTaskGroup<ScanTask>> taskGroups = ImmutableList.of(taskGroup);
+
+ String[][] locations = SparkPlanningUtil.assignExecutors(taskGroups,
EXECUTOR_LOCATIONS);
+
+ // should assign executors and handle different size of partitions
+ assertThat(locations.length).isEqualTo(1);
+ assertThat(locations[0].length).isGreaterThanOrEqualTo(1); // checks only that something was assigned, not which executors
+ }
+
+ @TestTemplate
+ public void testFileScanTaskWithUnpartitionedDeletes() {
+ List<ScanTask> tasks1 =
+ ImmutableList.of(
+ new MockFileScanTask(
+ mockDataFile(Row.of()),
+ mockDeleteFiles(2, Row.of()),
+ SCHEMA,
+ PartitionSpec.unpartitioned()),
+ new MockFileScanTask(
+ mockDataFile(Row.of()),
+ mockDeleteFiles(2, Row.of()),
+ SCHEMA,
+ PartitionSpec.unpartitioned()),
+ new MockFileScanTask(
+ mockDataFile(Row.of()),
+ mockDeleteFiles(2, Row.of()),
+ SCHEMA,
+ PartitionSpec.unpartitioned())); // first group: empty partition tuples
+ ScanTaskGroup<ScanTask> taskGroup1 = new BaseScanTaskGroup<>(tasks1);
+ List<ScanTask> tasks2 =
+ ImmutableList.of(
+ new MockFileScanTask(
+ mockDataFile(null),
+ mockDeleteFiles(2, null),
+ SCHEMA,
+ PartitionSpec.unpartitioned()),
+ new MockFileScanTask(
+ mockDataFile(null),
+ mockDeleteFiles(2, null),
+ SCHEMA,
+ PartitionSpec.unpartitioned()),
+ new MockFileScanTask(
+ mockDataFile(null),
+ mockDeleteFiles(2, null),
+ SCHEMA,
+ PartitionSpec.unpartitioned())); // second group: null partition tuples
+ ScanTaskGroup<ScanTask> taskGroup2 = new BaseScanTaskGroup<>(tasks2);
+ List<ScanTaskGroup<ScanTask>> taskGroups = ImmutableList.of(taskGroup1,
taskGroup2);
+
+ String[][] locations = SparkPlanningUtil.assignExecutors(taskGroups,
EXECUTOR_LOCATIONS);
+
+ // should not assign executors if the table is unpartitioned
+ assertThat(locations.length).isEqualTo(2);
+ assertThat(locations[0]).isEmpty();
+ assertThat(locations[1]).isEmpty();
+ }
+
+ @TestTemplate
+ public void testDataTasks() {
+ List<ScanTask> tasks =
+ ImmutableList.of(
+ new MockDataTask(mockDataFile(Row.of(1, "a"))),
+ new MockDataTask(mockDataFile(Row.of(2, "b"))),
+ new MockDataTask(mockDataFile(Row.of(3, "c"))));
+ ScanTaskGroup<ScanTask> taskGroup = new BaseScanTaskGroup<>(tasks);
+ List<ScanTaskGroup<ScanTask>> taskGroups = ImmutableList.of(taskGroup);
+
+ String[][] locations = SparkPlanningUtil.assignExecutors(taskGroups,
EXECUTOR_LOCATIONS);
+
+ // should not assign executors for data tasks
+ assertThat(locations.length).isEqualTo(1);
+ assertThat(locations[0]).isEmpty();
+ }
+
+ @TestTemplate
+ public void testUnknownTasks() {
+ List<ScanTask> tasks = ImmutableList.of(new UnknownScanTask(), new
UnknownScanTask());
+ ScanTaskGroup<ScanTask> taskGroup = new BaseScanTaskGroup<>(tasks);
+ List<ScanTaskGroup<ScanTask>> taskGroups = ImmutableList.of(taskGroup);
+
+ String[][] locations = SparkPlanningUtil.assignExecutors(taskGroups,
EXECUTOR_LOCATIONS);
+
+ // should not assign executors for unknown tasks
+ assertThat(locations.length).isEqualTo(1);
+ assertThat(locations[0]).isEmpty();
+ }
+
+ private static DataFile mockDataFile(StructLike partition) { // Mockito data file that stubs only partition()
+ DataFile file = Mockito.mock(DataFile.class);
+ when(file.partition()).thenReturn(partition);
+ return file;
+ }
+
+ private static DeleteFile[] mockDeleteFiles(int count, StructLike partition)
{ // builds count mocked delete files that all report the given partition
+ DeleteFile[] files = new DeleteFile[count];
+ for (int index = 0; index < count; index++) {
+ files[index] = mockDeleteFile(partition);
+ }
+ return files;
+ }
+
+ private static DeleteFile mockDeleteFile(StructLike partition) { // Mockito delete file that stubs only partition()
+ DeleteFile file = Mockito.mock(DeleteFile.class);
+ when(file.partition()).thenReturn(partition);
+ return file;
+ }
+
+ private static class MockDataTask extends MockFileScanTask implements
DataTask { // file scan task mock that additionally implements DataTask
+
+ MockDataTask(DataFile file) {
+ super(file);
+ }
+
+ @Override
+ public PartitionSpec spec() {
+ return PartitionSpec.unpartitioned(); // always reports an unpartitioned spec
+ }
+
+ @Override
+ public CloseableIterable<StructLike> rows() {
+ throw new UnsupportedOperationException(); // not implemented for this mock
+ }
+ }
+
+ private static class UnknownScanTask implements ScanTask {} // bare ScanTask with no overrides
+}