This is an automated email from the ASF dual-hosted git repository.
jin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-toolchain.git
The following commit(s) were added to refs/heads/master by this push:
     new f58194b6 feat(loader): update loading logic & enhance failure case (#704)
f58194b6 is described below
commit f58194b69dc990bcf5fde6becc74ef25873a8c94
Author: Ken <[email protected]>
AuthorDate: Wed Jan 21 22:44:33 2026 +0800
feat(loader): update loading logic & enhance failure case (#704)
* Create auto-pr-review.yml (#1)
    * [feat] Adjusted several default parameters and descriptions in the Loader, and refactored the failure-handling logic for batch inserts.
* [feat] Standardize log levels and streamline parameter descriptions.
* Update required review count and collaborators
---------
Co-authored-by: imbajin <[email protected]>
---
.asf.yaml | 6 +-
.../apache/hugegraph/loader/HugeGraphLoader.java | 11 +-
.../hugegraph/loader/executor/LoadOptions.java | 53 +-
.../apache/hugegraph/loader/task/TaskManager.java | 21 +-
.../loader/test/functional/FileLoadTest.java | 27 +-
.../loader/test/functional/KafkaLoadTest.java | 3 +-
.../loader/test/unit/LoadOptionsTest.java | 213 ++++++++
.../loader/test/unit/TaskManagerTest.java | 536 +++++++++++++++++++++
8 files changed, 838 insertions(+), 32 deletions(-)
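
Note: --parallel-count remains a parseable alias of the new --parser-threads flag; it just fills the renamed field and logs a deprecation warning. A minimal sketch of the parse path, assuming a placeholder mapping file named struct.json (parseOptions() requires the -f argument, as LoadOptionsTest below also does):

    String[] args = {
            "-f", "struct.json",
            "-g", "hugegraph",
            "-h", "localhost",
            "--parallel-count", "4"        // deprecated alias, still accepted
    };
    LoadOptions options = LoadOptions.parseOptions(args);
    assert options.parseThreads == 4;      // alias mapped onto the new field
    assert options.batchFailureFallback;   // new flag defaults to true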
diff --git a/.asf.yaml b/.asf.yaml
index 83de7215..1d9d56e2 100644
--- a/.asf.yaml
+++ b/.asf.yaml
@@ -49,14 +49,12 @@ github:
required_pull_request_reviews:
dismiss_stale_reviews: true
require_code_owner_reviews: false
- required_approving_review_count: 2
+ required_approving_review_count: 1
      # (for non-committer): assign/edit/close issues & PR, without write access to the code
collaborators:
- - Pengzna
+ - kenssa4eedfd
- haohao0103
- - Thespica
- FrostyHec
- - MuLeiSY2021
notifications:
# use https://selfserve.apache.org to manage it
diff --git a/hugegraph-loader/src/main/java/org/apache/hugegraph/loader/HugeGraphLoader.java b/hugegraph-loader/src/main/java/org/apache/hugegraph/loader/HugeGraphLoader.java
index 2fb9eb4a..dab3fb57 100644
--- a/hugegraph-loader/src/main/java/org/apache/hugegraph/loader/HugeGraphLoader.java
+++ b/hugegraph-loader/src/main/java/org/apache/hugegraph/loader/HugeGraphLoader.java
@@ -662,24 +662,21 @@ public final class HugeGraphLoader {
}
private void loadStructs(List<InputStruct> structs) {
- int parallelCount = this.context.options().parallelCount;
+ int parseThreads = this.context.options().parseThreads;
if (structs.size() == 0) {
return;
}
- if (parallelCount <= 0) {
-            parallelCount = Math.min(structs.size(), Runtime.getRuntime().availableProcessors() * 2);
- }
boolean scatter = this.context.options().scatterSources;
-        LOG.info("{} threads for loading {} structs, from {} to {} in {} mode",
-                 parallelCount, structs.size(), this.context.options().startFile,
+        LOG.info("{} parser threads for loading {} structs, from {} to {} in {} mode",
+                 parseThreads, structs.size(), this.context.options().startFile,
                  this.context.options().endFile, scatter ? "scatter" : "sequential");
ExecutorService loadService = null;
try {
-            loadService = ExecutorUtil.newFixedThreadPool(parallelCount, "loader");
+            loadService = ExecutorUtil.newFixedThreadPool(parseThreads, "loader");
List<InputTaskItem> taskItems = prepareTaskItems(structs, scatter);
List<CompletableFuture<Void>> loadTasks = new ArrayList<>();
diff --git a/hugegraph-loader/src/main/java/org/apache/hugegraph/loader/executor/LoadOptions.java b/hugegraph-loader/src/main/java/org/apache/hugegraph/loader/executor/LoadOptions.java
index 95babb55..f0ea30b7 100644
--- a/hugegraph-loader/src/main/java/org/apache/hugegraph/loader/executor/LoadOptions.java
+++ b/hugegraph-loader/src/main/java/org/apache/hugegraph/loader/executor/LoadOptions.java
@@ -20,6 +20,7 @@ package org.apache.hugegraph.loader.executor;
import java.io.File;
import java.lang.reflect.Field;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.List;
import java.util.Set;
@@ -45,6 +46,8 @@ public final class LoadOptions implements Cloneable {
public static final String HTTPS_SCHEMA = "https";
public static final String HTTP_SCHEMA = "http";
private static final int CPUS = Runtime.getRuntime().availableProcessors();
+ private static final int DEFAULT_MAX_CONNECTIONS = CPUS * 4;
+ private static final int DEFAULT_MAX_CONNECTIONS_PER_ROUTE = CPUS * 2;
private static final int MINIMUM_REQUIRED_ARGS = 3;
@Parameter(names = {"-f", "--file"}, required = true, arity = 1,
@@ -156,7 +159,7 @@ public final class LoadOptions implements Cloneable {
@Parameter(names = {"--batch-insert-threads"}, arity = 1,
validateWith = {PositiveValidator.class},
- description = "The number of threads to execute batch insert")
+ description = "The number of threads to execute batch insert
(default: CPUS)")
public int batchInsertThreads = CPUS;
@Parameter(names = {"--single-insert-threads"}, arity = 1,
@@ -165,21 +168,27 @@ public final class LoadOptions implements Cloneable {
public int singleInsertThreads = 8;
@Parameter(names = {"--max-conn"}, arity = 1,
- description = "Max number of HTTP connections to server")
- public int maxConnections = CPUS * 4;
+ validateWith = {PositiveValidator.class},
+ description = "Max HTTP connections (default: CPUS*4;
auto-adjusted by " +
+ "--batch-insert-threads)")
+ public int maxConnections = DEFAULT_MAX_CONNECTIONS;
@Parameter(names = {"--max-conn-per-route"}, arity = 1,
- description = "Max number of HTTP connections to each route")
- public int maxConnectionsPerRoute = CPUS * 2;
+ validateWith = {PositiveValidator.class},
+ description = "Max HTTP connections per route (default: CPUS*2;
" +
+ "auto-adjusted by --batch-insert-threads)")
+ public int maxConnectionsPerRoute = DEFAULT_MAX_CONNECTIONS_PER_ROUTE;
@Parameter(names = {"--batch-size"}, arity = 1,
validateWith = {PositiveValidator.class},
description = "The number of lines in each submit")
public int batchSize = 500;
- @Parameter(names = {"--parallel-count"}, arity = 1,
- description = "The number of parallel read pipelines")
- public int parallelCount = 1;
+ @Parameter(names = {"--parallel-count", "--parser-threads"}, arity = 1,
+ validateWith = {PositiveValidator.class},
+ description = "Parallel read pipelines (default: max(2,
CPUS/2); " +
+ "--parallel-count is deprecated)")
+ public int parseThreads = Math.max(2, CPUS / 2);
@Parameter(names = {"--start-file"}, arity = 1,
description = "start file index for partial loading")
@@ -329,6 +338,11 @@ public final class LoadOptions implements Cloneable {
description = "The task scheduler type (when creating graph if
not exists")
public String schedulerType = "distributed";
+ @Parameter(names = {"--batch-failure-fallback"}, arity = 1,
+ description = "Whether to fallback to single insert when batch
insert fails. " +
+ "Default: true")
+ public boolean batchFailureFallback = true;
+
public String workModeString() {
if (this.incrementalMode) {
return "INCREMENTAL MODE";
@@ -406,9 +420,32 @@ public final class LoadOptions implements Cloneable {
options.maxParseErrors = Constants.NO_LIMIT;
options.maxInsertErrors = Constants.NO_LIMIT;
}
+ if (Arrays.asList(args).contains("--parallel-count")) {
+ LOG.warn("Parameter --parallel-count is deprecated, " +
+ "please use --parser-threads instead");
+ }
+ adjustConnectionPoolIfDefault(options);
return options;
}
+ private static void adjustConnectionPoolIfDefault(LoadOptions options) {
+ int batchThreads = options.batchInsertThreads;
+ int maxConn = options.maxConnections;
+ int maxConnPerRoute = options.maxConnectionsPerRoute;
+
+ if (maxConn == DEFAULT_MAX_CONNECTIONS && maxConn < batchThreads * 4) {
+ options.maxConnections = batchThreads * 4;
+ LOG.info("Auto adjusted max-conn to {} based on
batch-insert-threads({})",
+ options.maxConnections, batchThreads);
+ }
+
+        if (maxConnPerRoute == DEFAULT_MAX_CONNECTIONS_PER_ROUTE && maxConnPerRoute < batchThreads * 2) {
+ options.maxConnectionsPerRoute = batchThreads * 2;
+ LOG.info("Auto adjusted max-conn-per-route to {} based on
batch-insert-threads({})",
+ options.maxConnectionsPerRoute, batchThreads);
+ }
+ }
+
public ShortIdConfig getShortIdConfig(String vertexLabel) {
for (ShortIdConfig config: shorterIDConfigs) {
if (config.getVertexLabel().equals(vertexLabel)) {
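
Note: adjustConnectionPoolIfDefault() only widens pools that are still at their compile-time defaults. A worked example of the arithmetic, assuming a hypothetical 8-core host (CPUS = 8):

    int batchThreads = 20;                   // --batch-insert-threads 20
    int maxConn = 8 * 4;                     // 32, equals DEFAULT_MAX_CONNECTIONS
    if (maxConn < batchThreads * 4) {        // 32 < 80: still at default, so widen
        maxConn = batchThreads * 4;          // -> 80
    }
    int maxConnPerRoute = 8 * 2;             // 16, equals DEFAULT_MAX_CONNECTIONS_PER_ROUTE
    if (maxConnPerRoute < batchThreads * 2) {
        maxConnPerRoute = batchThreads * 2;  // -> 40
    }
    // A user-supplied --max-conn no longer equals the default constant,
    // so it is left untouched.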
diff --git a/hugegraph-loader/src/main/java/org/apache/hugegraph/loader/task/TaskManager.java b/hugegraph-loader/src/main/java/org/apache/hugegraph/loader/task/TaskManager.java
index ce4d77a9..7d079395 100644
--- a/hugegraph-loader/src/main/java/org/apache/hugegraph/loader/task/TaskManager.java
+++ b/hugegraph-loader/src/main/java/org/apache/hugegraph/loader/task/TaskManager.java
@@ -26,6 +26,7 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
+import org.apache.hugegraph.loader.util.Printer;
import org.slf4j.Logger;
import org.apache.hugegraph.loader.builder.Record;
@@ -140,6 +141,10 @@ public final class TaskManager {
long start = System.currentTimeMillis();
try {
this.batchSemaphore.acquire();
+ if (this.context.stopped()) {
+ this.batchSemaphore.release();
+ return;
+ }
} catch (InterruptedException e) {
throw new LoadException("Interrupted while waiting to submit %s " +
"batch in batch mode", e, mapping.type());
@@ -152,10 +157,18 @@ public final class TaskManager {
CompletableFuture.runAsync(task, this.batchService).whenComplete(
(r, e) -> {
if (e != null) {
- LOG.warn("Batch insert {} error, try single insert",
- mapping.type(), e);
- // The time of single insert is counted separately
- this.submitInSingle(struct, mapping, batch);
+ if (this.options.batchFailureFallback) {
+ LOG.warn("Batch insert {} error, try single insert",
+ mapping.type(), e);
+ this.submitInSingle(struct, mapping, batch);
+ } else {
+ summary.metrics(struct).minusFlighting(batch.size());
+ this.context.occurredError();
+ this.context.stopLoading();
+                        LOG.error("Batch insert {} error, interrupting import", mapping.type(), e);
+                        Printer.printError("Batch insert %s failed, stop loading. Please check the logs",
+                                           mapping.type().string());
+ }
} else {
summary.metrics(struct).minusFlighting(batch.size());
}
diff --git a/hugegraph-loader/src/test/java/org/apache/hugegraph/loader/test/functional/FileLoadTest.java b/hugegraph-loader/src/test/java/org/apache/hugegraph/loader/test/functional/FileLoadTest.java
index d069aaec..5be6a61e 100644
--- a/hugegraph-loader/src/test/java/org/apache/hugegraph/loader/test/functional/FileLoadTest.java
+++ b/hugegraph-loader/src/test/java/org/apache/hugegraph/loader/test/functional/FileLoadTest.java
@@ -1197,7 +1197,9 @@ public class FileLoadTest extends LoadTest {
"-s", configPath("multi_files_have_header/schema.groovy"),
"-g", GRAPH,
"-h", SERVER,
- "--test-mode", "true"
+ "--test-mode", "true",
+                // FIXME: Set parser-threads to 1 because values > 1 currently trigger a NullPointerException (NPE).
+                "--parser-threads", "1"
};
loadWithAuth(args);
@@ -1332,7 +1334,9 @@ public class FileLoadTest extends LoadTest {
"-s", configPath("dir_has_multi_files/schema.groovy"),
"-g", GRAPH,
"-h", SERVER,
- "--test-mode", "true"
+ "--test-mode", "true",
+                // FIXME: Set parser-threads to 1 because values > 1 currently trigger a NullPointerException (NPE).
+                "--parser-threads", "1"
};
loadWithAuth(args);
@@ -1628,7 +1632,9 @@ public class FileLoadTest extends LoadTest {
"-s", configPath("filter_path_by_suffix/schema.groovy"),
"-g", GRAPH,
"-h", SERVER,
- "--test-mode", "true"
+ "--test-mode", "true",
+                // FIXME: Set parser-threads to 1 because values > 1 currently trigger a NullPointerException (NPE).
+                "--parser-threads", "1"
};
loadWithAuth(args);
@@ -2058,7 +2064,8 @@ public class FileLoadTest extends LoadTest {
"-h", SERVER,
"--batch-insert-threads", "2",
"--max-parse-errors", "1",
- "--test-mode", "false"
+ "--test-mode", "false",
+ "--parser-threads", "1"
));
         argsList.addAll(Arrays.asList("--username", "admin", "--password", "pa"));
@@ -2259,7 +2266,8 @@ public class FileLoadTest extends LoadTest {
"-h", SERVER,
"--check-vertex", "true",
"--batch-insert-threads", "2",
- "--test-mode", "false"
+ "--test-mode", "false",
+ "--parser-threads", "1"
));
         argsList.addAll(Arrays.asList("--username", "admin", "--password", "pa"));
         HugeGraphLoader loader = new HugeGraphLoader(argsList.toArray(new String[0]));
@@ -2564,7 +2572,8 @@ public class FileLoadTest extends LoadTest {
"-g", GRAPH,
"-h", SERVER,
"--batch-insert-threads", "2",
- "--test-mode", "true"
+ "--test-mode", "true",
+ "--parser-threads", "1"
));
         argsList.addAll(Arrays.asList("--username", "admin", "--password", "pa"));
@@ -3047,7 +3056,8 @@ public class FileLoadTest extends LoadTest {
"-h", SERVER,
"--max-read-lines", "4",
"--batch-insert-threads", "2",
- "--test-mode", "true"
+ "--test-mode", "true",
+ "--parser-threads", "1"
};
loadWithAuth(args);
@@ -3061,7 +3071,8 @@ public class FileLoadTest extends LoadTest {
"-h", SERVER,
"--max-read-lines", "6",
"--batch-insert-threads", "2",
- "--test-mode", "true"
+ "--test-mode", "true",
+ "--parser-threads", "1"
};
loadWithAuth(args);
diff --git a/hugegraph-loader/src/test/java/org/apache/hugegraph/loader/test/functional/KafkaLoadTest.java b/hugegraph-loader/src/test/java/org/apache/hugegraph/loader/test/functional/KafkaLoadTest.java
index c6c31520..b44ffbd8 100644
--- a/hugegraph-loader/src/test/java/org/apache/hugegraph/loader/test/functional/KafkaLoadTest.java
+++ b/hugegraph-loader/src/test/java/org/apache/hugegraph/loader/test/functional/KafkaLoadTest.java
@@ -85,7 +85,8 @@ public class KafkaLoadTest extends LoadTest {
"-h", SERVER,
"-p", String.valueOf(PORT),
"--batch-insert-threads", "2",
- "--test-mode", "true"
+ "--test-mode", "true",
+ "--parser-threads", "1"
};
loadWithAuth(args);
diff --git a/hugegraph-loader/src/test/java/org/apache/hugegraph/loader/test/unit/LoadOptionsTest.java b/hugegraph-loader/src/test/java/org/apache/hugegraph/loader/test/unit/LoadOptionsTest.java
new file mode 100644
index 00000000..b327f59e
--- /dev/null
+++ b/hugegraph-loader/src/test/java/org/apache/hugegraph/loader/test/unit/LoadOptionsTest.java
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with this
+ * work for additional information regarding copyright ownership. The ASF
+ * licenses this file to You under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.hugegraph.loader.test.unit;
+
+import java.io.File;
+import java.io.FileWriter;
+import java.lang.reflect.Field;
+import java.lang.reflect.Method;
+
+import org.apache.log4j.AppenderSkeleton;
+import org.apache.log4j.Level;
+import org.apache.log4j.Logger;
+import org.apache.log4j.spi.LoggingEvent;
+
+import org.apache.hugegraph.loader.executor.LoadOptions;
+import org.junit.Test;
+
+import org.apache.hugegraph.testutil.Assert;
+
+public class LoadOptionsTest {
+
+ @Test
+    public void testConnectionPoolAutoAdjustWithDefaultBatchThreads() throws Exception {
+ int cpus = readStaticInt(LoadOptions.class, "CPUS");
+ LoadOptions options = new LoadOptions();
+
+ Assert.assertEquals(cpus * 4, options.maxConnections);
+ Assert.assertEquals(cpus * 2, options.maxConnectionsPerRoute);
+ }
+
+ @Test
+    public void testConnectionPoolAutoAdjustWithCustomBatchThreads() throws Exception {
+        int cpus = readStaticInt(LoadOptions.class, "CPUS");
+        int defaultMaxConn = readStaticInt(LoadOptions.class, "DEFAULT_MAX_CONNECTIONS");
+        int defaultMaxConnPerRoute = readStaticInt(LoadOptions.class,
+                                                   "DEFAULT_MAX_CONNECTIONS_PER_ROUTE");
+ LoadOptions options = new LoadOptions();
+ options.batchInsertThreads = 20;
+
+ CapturingAppender appender = attachAppender();
+ try {
+ invokeAdjustConnectionPool(options);
+ } finally {
+ detachAppender(appender);
+ }
+
+ int expectedMaxConn = defaultMaxConn;
+ int expectedMaxConnPerRoute = defaultMaxConnPerRoute;
+ if (defaultMaxConn == cpus * 4 && defaultMaxConn < 80) {
+ expectedMaxConn = 80;
+ }
+        if (defaultMaxConnPerRoute == cpus * 2 && defaultMaxConnPerRoute < 40) {
+ expectedMaxConnPerRoute = 40;
+ }
+
+ Assert.assertEquals(expectedMaxConn, options.maxConnections);
+        Assert.assertEquals(expectedMaxConnPerRoute, options.maxConnectionsPerRoute);
+ if (expectedMaxConn == 80 || expectedMaxConnPerRoute == 40) {
+ Assert.assertTrue(appender.contains("Auto adjusted max-conn"));
+ }
+ }
+
+ @Test
+    public void testConnectionPoolNoAdjustWithCustomMaxConn() throws Exception {
+ LoadOptions options = new LoadOptions();
+ options.batchInsertThreads = 20;
+ options.maxConnections = 100;
+ options.maxConnectionsPerRoute = 50;
+
+ CapturingAppender appender = attachAppender();
+ try {
+ invokeAdjustConnectionPool(options);
+ } finally {
+ detachAppender(appender);
+ }
+
+ Assert.assertEquals(100, options.maxConnections);
+ Assert.assertEquals(50, options.maxConnectionsPerRoute);
+ Assert.assertFalse(appender.contains("Auto adjusted max-conn"));
+ }
+
+ @Test
+ public void testParseThreadsMinValue() {
+ LoadOptions.PositiveValidator validator =
+ new LoadOptions.PositiveValidator();
+
+ validator.validate("--parser-threads", "1");
+
+ Assert.assertTrue(validateFails(validator, "--parser-threads", "0"));
+ Assert.assertTrue(validateFails(validator, "--parser-threads", "-1"));
+ }
+
+ @Test
+ public void testParseThreadsDefaultValue() throws Exception {
+ int cpus = readStaticInt(LoadOptions.class, "CPUS");
+ LoadOptions options = new LoadOptions();
+ Assert.assertEquals(Math.max(2, cpus / 2), options.parseThreads);
+ }
+
+ @Test
+ public void testDeprecatedParallelCountParameter() throws Exception {
+ File mapping = createTempMapping();
+ String[] args = new String[]{
+ "-f", mapping.getPath(),
+ "-g", "g",
+ "-h", "localhost",
+ "--parallel-count", "4"
+ };
+
+ CapturingAppender appender = attachAppender();
+ try {
+ LoadOptions options = LoadOptions.parseOptions(args);
+ Assert.assertEquals(4, options.parseThreads);
+ Assert.assertTrue(appender.contains("deprecated"));
+ } finally {
+ detachAppender(appender);
+ mapping.delete();
+ }
+ }
+
+ private static int readStaticInt(Class<?> type, String name)
+ throws Exception {
+ Field field = type.getDeclaredField(name);
+ field.setAccessible(true);
+ return field.getInt(null);
+ }
+
+ private static void invokeAdjustConnectionPool(LoadOptions options)
+ throws Exception {
+        Method method = LoadOptions.class
+                .getDeclaredMethod("adjustConnectionPoolIfDefault",
+                                   LoadOptions.class);
+ method.setAccessible(true);
+ method.invoke(null, options);
+ }
+
+    private static boolean validateFails(LoadOptions.PositiveValidator validator,
+                                         String name, String value) {
+ try {
+ validator.validate(name, value);
+ return false;
+ } catch (Exception ignored) {
+ return true;
+ }
+ }
+
+ private static File createTempMapping() throws Exception {
+        File file = File.createTempFile("load-options-", ".json", new File("."));
+ try (FileWriter writer = new FileWriter(file)) {
+ writer.write("{\"version\":\"2.0\",\"structs\":[]}");
+ }
+ return file;
+ }
+
+ private static CapturingAppender attachAppender() {
+ Logger logger = Logger.getLogger(LoadOptions.class.getName());
+ CapturingAppender appender = new CapturingAppender();
+ appender.setThreshold(Level.INFO);
+ logger.addAppender(appender);
+ return appender;
+ }
+
+ private static void detachAppender(CapturingAppender appender) {
+ if (appender == null) {
+ return;
+ }
+ Logger logger = Logger.getLogger(LoadOptions.class.getName());
+ logger.removeAppender(appender);
+ }
+
+ private static final class CapturingAppender extends AppenderSkeleton {
+
+ private final StringBuilder buffer = new StringBuilder();
+
+ @Override
+ protected void append(LoggingEvent event) {
+ if (event == null || event.getRenderedMessage() == null) {
+ return;
+ }
+ buffer.append(event.getRenderedMessage()).append('\n');
+ }
+
+ boolean contains(String text) {
+ return this.buffer.toString().contains(text);
+ }
+
+ @Override
+ public void close() {
+ // No-op.
+ }
+
+ @Override
+ public boolean requiresLayout() {
+ return false;
+ }
+ }
+
+}
diff --git a/hugegraph-loader/src/test/java/org/apache/hugegraph/loader/test/unit/TaskManagerTest.java b/hugegraph-loader/src/test/java/org/apache/hugegraph/loader/test/unit/TaskManagerTest.java
new file mode 100644
index 00000000..08b6b904
--- /dev/null
+++ b/hugegraph-loader/src/test/java/org/apache/hugegraph/loader/test/unit/TaskManagerTest.java
@@ -0,0 +1,536 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with this
+ * work for additional information regarding copyright ownership. The ASF
+ * licenses this file to You under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.hugegraph.loader.test.unit;
+
+import java.io.ByteArrayOutputStream;
+import java.io.PrintStream;
+import java.lang.reflect.Field;
+import java.lang.reflect.Method;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.Semaphore;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.LongAdder;
+
+import org.apache.hugegraph.driver.GraphManager;
+import org.apache.hugegraph.driver.HugeClient;
+import org.apache.hugegraph.loader.builder.Record;
+import org.apache.hugegraph.loader.executor.LoadContext;
+import org.apache.hugegraph.loader.executor.LoadOptions;
+import org.apache.hugegraph.loader.mapping.EdgeMapping;
+import org.apache.hugegraph.loader.mapping.InputStruct;
+import org.apache.hugegraph.loader.metrics.LoadMetrics;
+import org.apache.hugegraph.loader.metrics.LoadSummary;
+import org.apache.hugegraph.loader.progress.LoadProgress;
+import org.apache.hugegraph.loader.task.TaskManager;
+import org.apache.hugegraph.structure.graph.Edge;
+import org.junit.Test;
+
+import org.apache.hugegraph.testutil.Assert;
+
+public class TaskManagerTest {
+
+ @Test
+ public void testBatchInsertFailureWithFallbackDisabled() throws Exception {
+ LoadOptions options = new LoadOptions();
+ options.batchFailureFallback = false;
+ Assert.assertFalse(options.batchFailureFallback);
+
+ LoadContext context = newTestContext(options);
+ TaskManager taskManager = new TaskManager(context);
+
+ EdgeMapping mapping = new EdgeMapping(Arrays.asList("s"), false,
+ Arrays.asList("t"), false);
+ mapping.label("knows");
+
+ InputStruct struct = new InputStruct(new ArrayList<>(),
+ new ArrayList<>());
+ struct.id("1");
+ struct.add(mapping);
+
+ LoadSummary summary = context.summary();
+ summary.inputMetricsMap()
+ .put(struct.id(), new LoadMetrics(struct));
+ LoadMetrics metrics = summary.metrics(struct);
+
+ setField(context.client(), "graph", newFailingBatchGraphManager());
+
+ List<Record> batch = new ArrayList<>();
+ batch.add(new Record("line1", new Edge("knows")));
+ batch.add(new Record("line2", new Edge("knows")));
+
+ ByteArrayOutputStream errOutput = new ByteArrayOutputStream();
+ PrintStream originalErr = System.err;
+ System.setErr(new PrintStream(errOutput, true,
+ StandardCharsets.UTF_8.name()));
+ try {
+ taskManager.submitBatch(struct, mapping, batch);
+ taskManager.waitFinished();
+
+ Assert.assertEquals(0L, flightingCount(metrics));
+ Assert.assertTrue(context.stopped());
+ Assert.assertFalse(context.noError());
+
+ String errText = errOutput.toString(StandardCharsets.UTF_8.name());
+ Assert.assertTrue(errText.contains(
+ "Batch insert edges failed, stop loading."));
+
+ long before = flightingCount(metrics);
+ taskManager.submitBatch(struct, mapping, batch);
+ taskManager.waitFinished();
+ Assert.assertEquals(before, flightingCount(metrics));
+ } finally {
+ System.setErr(originalErr);
+ taskManager.shutdown();
+ }
+ }
+
+ @Test
+ public void testBatchInsertFailureWithFallbackEnabled() throws Exception {
+ LoadOptions options = new LoadOptions();
+ options.batchFailureFallback = true;
+
+ LoadContext context = newTestContext(options);
+ TaskManager taskManager = new TaskManager(context);
+
+ EdgeMapping mapping = new EdgeMapping(Arrays.asList("s"), false,
+ Arrays.asList("t"), false);
+ mapping.label("knows");
+
+ InputStruct struct = new InputStruct(new ArrayList<>(),
+ new ArrayList<>());
+ struct.id("1");
+ struct.add(mapping);
+
+ LoadSummary summary = context.summary();
+ summary.inputMetricsMap()
+ .put(struct.id(), new LoadMetrics(struct));
+ LoadMetrics metrics = summary.metrics(struct);
+
+ FailingBatchGraphManager.BATCH_CALLS.set(0);
+ FailingBatchGraphManager.SINGLE_CALLS.set(0);
+ setField(context.client(), "graph", newFailingBatchGraphManager());
+
+ List<Record> batch = new ArrayList<>();
+ batch.add(new Record("line1", new Edge("knows")));
+ batch.add(new Record("line2", new Edge("knows")));
+
+ try {
+ taskManager.submitBatch(struct, mapping, batch);
+ taskManager.waitFinished();
+
+ Assert.assertEquals(1, FailingBatchGraphManager.BATCH_CALLS.get());
+            Assert.assertEquals(2, FailingBatchGraphManager.SINGLE_CALLS.get());
+ Assert.assertEquals(0L, flightingCount(metrics));
+ Assert.assertFalse(context.stopped());
+ Assert.assertTrue(context.noError());
+ } finally {
+ taskManager.shutdown();
+ }
+ }
+
+ @Test
+    public void testMultipleBatchFailuresCounterConsistency() throws Exception {
+ LoadOptions options = new LoadOptions();
+ options.batchFailureFallback = true;
+
+ LoadContext context = newTestContext(options);
+ TaskManager taskManager = new TaskManager(context);
+
+ EdgeMapping mapping = new EdgeMapping(Arrays.asList("s"), false,
+ Arrays.asList("t"), false);
+ mapping.label("knows");
+
+ InputStruct struct = new InputStruct(new ArrayList<>(),
+ new ArrayList<>());
+ struct.id("1");
+ struct.add(mapping);
+
+ LoadSummary summary = context.summary();
+ summary.inputMetricsMap()
+ .put(struct.id(), new LoadMetrics(struct));
+ LoadMetrics metrics = summary.metrics(struct);
+
+ FailingBatchGraphManager.BATCH_CALLS.set(0);
+ FailingBatchGraphManager.SINGLE_CALLS.set(0);
+ setField(context.client(), "graph", newFailingBatchGraphManager());
+
+ List<Record> batch1 = new ArrayList<>();
+ batch1.add(new Record("line1", new Edge("knows")));
+ batch1.add(new Record("line2", new Edge("knows")));
+
+ List<Record> batch2 = new ArrayList<>();
+ batch2.add(new Record("line3", new Edge("knows")));
+ batch2.add(new Record("line4", new Edge("knows")));
+
+ try {
+ taskManager.submitBatch(struct, mapping, batch1);
+ taskManager.submitBatch(struct, mapping, batch2);
+ taskManager.waitFinished();
+
+ Assert.assertEquals(2, FailingBatchGraphManager.BATCH_CALLS.get());
+            Assert.assertEquals(4, FailingBatchGraphManager.SINGLE_CALLS.get());
+ Assert.assertEquals(0L, flightingCount(metrics));
+ Assert.assertFalse(context.stopped());
+ Assert.assertTrue(context.noError());
+
+ int expectedBatchPermits = 1 + options.batchInsertThreads;
+ int expectedSinglePermits = 2 * options.singleInsertThreads;
+            Assert.assertEquals(expectedBatchPermits,
+                                getSemaphorePermits(taskManager, "batchSemaphore"));
+            Assert.assertEquals(expectedSinglePermits,
+                                getSemaphorePermits(taskManager, "singleSemaphore"));
+ } finally {
+ taskManager.shutdown();
+ }
+ }
+
+ @Test
+ public void testConcurrentSubmitWhenStopping() throws Exception {
+ LoadOptions options = new LoadOptions();
+ options.batchFailureFallback = false;
+ options.batchInsertThreads = 2;
+ options.singleInsertThreads = 1;
+
+ LoadContext context = newTestContext(options);
+ TaskManager taskManager = new TaskManager(context);
+
+ EdgeMapping mapping = new EdgeMapping(Arrays.asList("s"), false,
+ Arrays.asList("t"), false);
+ mapping.label("knows");
+
+ InputStruct struct = new InputStruct(new ArrayList<>(),
+ new ArrayList<>());
+ struct.id("1");
+ struct.add(mapping);
+
+ LoadSummary summary = context.summary();
+ summary.inputMetricsMap()
+ .put(struct.id(), new LoadMetrics(struct));
+ LoadMetrics metrics = summary.metrics(struct);
+
+ CountDownLatch firstStarted = new CountDownLatch(1);
+ CountDownLatch allowFirstFinish = new CountDownLatch(1);
+ CountDownLatch failureCalled = new CountDownLatch(1);
+ FailingConcurrentGraphManager.BATCH_CALLS.set(0);
+ FailingConcurrentGraphManager.FIRST_STARTED = firstStarted;
+ FailingConcurrentGraphManager.ALLOW_FIRST_FINISH = allowFirstFinish;
+ FailingConcurrentGraphManager.FAILURE_CALLED = failureCalled;
+ setField(context.client(), "graph",
newFailingConcurrentGraphManager());
+
+ List<Record> batch = new ArrayList<>();
+ batch.add(new Record("line1", new Edge("knows")));
+ batch.add(new Record("line2", new Edge("knows")));
+
+ ExecutorService executor = Executors.newFixedThreadPool(10);
+ List<Future<?>> futures = new ArrayList<>();
+ try {
+ for (int i = 0; i < 10; i++) {
+ futures.add(executor.submit(() -> {
+ taskManager.submitBatch(struct, mapping, batch);
+ }));
+ }
+
+ Assert.assertTrue(firstStarted.await(5, TimeUnit.SECONDS));
+ Assert.assertTrue(failureCalled.await(5, TimeUnit.SECONDS));
+ waitStopped(context, 5, TimeUnit.SECONDS);
+ allowFirstFinish.countDown();
+
+ for (Future<?> future : futures) {
+ future.get(5, TimeUnit.SECONDS);
+ }
+
+ taskManager.waitFinished();
+
+ int batchCalls = FailingConcurrentGraphManager.BATCH_CALLS.get();
+ Assert.assertTrue(batchCalls >= 2 && batchCalls <= 3);
+ Assert.assertEquals(0L, flightingCount(metrics));
+ Assert.assertTrue(context.stopped());
+ Assert.assertFalse(context.noError());
+
+ long before = FailingConcurrentGraphManager.BATCH_CALLS.get();
+ taskManager.submitBatch(struct, mapping, batch);
+ taskManager.waitFinished();
+            Assert.assertEquals(before, FailingConcurrentGraphManager.BATCH_CALLS.get());
+
+ int expectedBatchPermits = 1 + options.batchInsertThreads;
+ int expectedSinglePermits = 2 * options.singleInsertThreads;
+            Assert.assertEquals(expectedBatchPermits,
+                                getSemaphorePermits(taskManager, "batchSemaphore"));
+            Assert.assertEquals(expectedSinglePermits,
+                                getSemaphorePermits(taskManager, "singleSemaphore"));
+ } finally {
+ allowFirstFinish.countDown();
+ executor.shutdownNow();
+ taskManager.shutdown();
+ }
+ }
+
+ @Test
+ public void testStopCheckTimingInSubmitBatch() throws Exception {
+ LoadOptions options = new LoadOptions();
+ options.batchFailureFallback = false;
+ options.batchInsertThreads = 1;
+ options.singleInsertThreads = 1;
+
+ LoadContext context = newTestContext(options);
+ TaskManager taskManager = new TaskManager(context);
+
+ EdgeMapping mapping = new EdgeMapping(Arrays.asList("s"), false,
+ Arrays.asList("t"), false);
+ mapping.label("knows");
+
+ InputStruct struct = new InputStruct(new ArrayList<>(),
+ new ArrayList<>());
+ struct.id("1");
+ struct.add(mapping);
+
+ LoadSummary summary = context.summary();
+ summary.inputMetricsMap()
+ .put(struct.id(), new LoadMetrics(struct));
+ LoadMetrics metrics = summary.metrics(struct);
+
+ setField(context.client(), "graph", newSimpleGraphManager());
+
+ List<Record> batch = new ArrayList<>();
+ batch.add(new Record("line1", new Edge("knows")));
+ batch.add(new Record("line2", new Edge("knows")));
+
+ ExecutorService executor = Executors.newFixedThreadPool(2);
+ try {
+ taskManager.submitBatch(struct, mapping, batch);
+ taskManager.waitFinished();
+
+ Semaphore semaphore = getSemaphore(taskManager, "batchSemaphore");
+ semaphore.acquire();
+
+ Future<?> blocked = executor.submit(() -> {
+ taskManager.submitBatch(struct, mapping, batch);
+ });
+
+ Thread.sleep(50);
+ context.stopLoading();
+ semaphore.release();
+
+ blocked.get(5, TimeUnit.SECONDS);
+
+ taskManager.waitFinished();
+
+ Assert.assertTrue(context.stopped());
+ Assert.assertEquals(0L, flightingCount(metrics));
+ int expectedPermits = 1 + options.batchInsertThreads;
+            Assert.assertEquals(expectedPermits,
+                                getSemaphorePermits(taskManager, "batchSemaphore"));
+ } finally {
+ executor.shutdownNow();
+ taskManager.shutdown();
+ }
+ }
+
+ private static void waitStopped(LoadContext context, long timeout,
+ TimeUnit unit) throws Exception {
+ long deadline = System.nanoTime() + unit.toNanos(timeout);
+ while (!context.stopped() && System.nanoTime() < deadline) {
+ Thread.sleep(10);
+ }
+ Assert.assertTrue(context.stopped());
+ }
+
+ private static long flightingCount(LoadMetrics metrics)
+ throws Exception {
+ Field field = LoadMetrics.class.getDeclaredField("flightingNums");
+ field.setAccessible(true);
+ LongAdder adder = (LongAdder) field.get(metrics);
+ return adder.longValue();
+ }
+
+ private static LoadContext newTestContext(LoadOptions options)
+ throws Exception {
+        LoadContext context = (LoadContext) allocateInstance(LoadContext.class);
+ setField(context, "timestamp", "test");
+ setField(context, "closed", false);
+ setField(context, "stopped", false);
+ setField(context, "noError", true);
+ setField(context, "options", options);
+ setField(context, "summary", new LoadSummary());
+ setField(context, "oldProgress", new LoadProgress());
+ setField(context, "newProgress", new LoadProgress());
+ setField(context, "loggers", new ConcurrentHashMap<>());
+
+ HugeClient client = (HugeClient) allocateInstance(HugeClient.class);
+ setField(context, "client", client);
+ setField(context, "indirectClient", client);
+ setField(context, "schemaCache", null);
+ setField(context, "parseGroup", null);
+ return context;
+ }
+
+ private static Object allocateInstance(Class<?> type) throws Exception {
+ Object unsafe = unsafe();
+ Method method = unsafe.getClass()
+ .getMethod("allocateInstance", Class.class);
+ return method.invoke(unsafe, type);
+ }
+
+ private static Object unsafe() throws Exception {
+ Class<?> unsafeClass;
+ try {
+ unsafeClass = Class.forName("sun.misc.Unsafe");
+ } catch (ClassNotFoundException e) {
+ unsafeClass = Class.forName("jdk.internal.misc.Unsafe");
+ }
+ Field field = unsafeClass.getDeclaredField("theUnsafe");
+ field.setAccessible(true);
+ return field.get(null);
+ }
+
+ private static void setField(Object target, String name, Object value)
+ throws Exception {
+ Field field = target.getClass().getDeclaredField(name);
+ field.setAccessible(true);
+ field.set(target, value);
+ }
+
+ private static int getSemaphorePermits(Object target, String name)
+ throws Exception {
+ Field field = target.getClass().getDeclaredField(name);
+ field.setAccessible(true);
+ Semaphore semaphore = (Semaphore) field.get(target);
+ return semaphore.availablePermits();
+ }
+
+ private static Semaphore getSemaphore(Object target, String name)
+ throws Exception {
+ Field field = target.getClass().getDeclaredField(name);
+ field.setAccessible(true);
+ return (Semaphore) field.get(target);
+ }
+
+ private static GraphManager newFailingConcurrentGraphManager()
+ throws Exception {
+        return (GraphManager) allocateInstance(FailingConcurrentGraphManager.class);
+ }
+
+    private static GraphManager newFailingBatchGraphManager() throws Exception {
+ return (GraphManager) allocateInstance(FailingBatchGraphManager.class);
+ }
+
+ private static GraphManager newSimpleGraphManager() throws Exception {
+ return (GraphManager) allocateInstance(SimpleGraphManager.class);
+ }
+
+ private static final class SimpleGraphManager extends GraphManager {
+
+ private SimpleGraphManager() {
+ super(null, null, null);
+ }
+
+ @Override
+ public List<Edge> addEdges(List<Edge> edges, boolean checkVertex) {
+ return this.addEdges(edges);
+ }
+
+ @Override
+ public List<Edge> addEdges(List<Edge> edges) {
+ return edges;
+ }
+ }
+
+    private static final class FailingConcurrentGraphManager extends GraphManager {
+
+ private static final AtomicInteger BATCH_CALLS = new AtomicInteger();
+ private static volatile CountDownLatch FIRST_STARTED;
+ private static volatile CountDownLatch ALLOW_FIRST_FINISH;
+ private static volatile CountDownLatch FAILURE_CALLED;
+
+ private FailingConcurrentGraphManager() {
+ super(null, null, null);
+ }
+
+ @Override
+ public List<Edge> addEdges(List<Edge> edges, boolean checkVertex) {
+ return this.addEdges(edges);
+ }
+
+ @Override
+ public List<Edge> addEdges(List<Edge> edges) {
+ int call = BATCH_CALLS.incrementAndGet();
+ if (call == 1) {
+ CountDownLatch started = FIRST_STARTED;
+ if (started != null) {
+ started.countDown();
+ }
+ await(ALLOW_FIRST_FINISH);
+ return edges;
+ }
+ if (call == 2) {
+ CountDownLatch failed = FAILURE_CALLED;
+ if (failed != null) {
+ failed.countDown();
+ }
+ throw new RuntimeException("batch insert failure");
+ }
+ return edges;
+ }
+
+ private void await(CountDownLatch latch) {
+ if (latch == null) {
+ return;
+ }
+ try {
+ latch.await(5, TimeUnit.SECONDS);
+ } catch (InterruptedException ignored) {
+ // Let the task finish on interruption.
+ Thread.currentThread().interrupt();
+ }
+ }
+ }
+
+ private static final class FailingBatchGraphManager extends GraphManager {
+
+ private static final AtomicInteger BATCH_CALLS = new AtomicInteger();
+ private static final AtomicInteger SINGLE_CALLS = new AtomicInteger();
+
+ private FailingBatchGraphManager() {
+ super(null, null, null);
+ }
+
+ @Override
+ public List<Edge> addEdges(List<Edge> edges, boolean checkVertex) {
+ return this.addEdges(edges);
+ }
+
+ @Override
+ public List<Edge> addEdges(List<Edge> edges) {
+ if (edges.size() > 1) {
+ BATCH_CALLS.incrementAndGet();
+ throw new RuntimeException("batch insert failure");
+ }
+ SINGLE_CALLS.addAndGet(edges.size());
+ return edges;
+ }
+ }
+}