This is an automated email from the ASF dual-hosted git repository. xtsong pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/flink-agents.git
commit 8ac623d09f15b3a5c93b40cfb1027721e9210b18 Author: sxnan <[email protected]> AuthorDate: Wed Jan 14 16:29:17 2026 +0800 [runtime] Support Java async execution --- .../flink/agents/api/context/DurableCallable.java | 49 ++ .../flink/agents/api/context/RunnerContext.java | 23 + dist/pom.xml | 7 + .../pom.xml | 26 +- .../integration/test/AsyncExecutionAgent.java | 338 +++++++++++++ .../integration/test/AsyncExecutionTest.java | 390 +++++++++++++++ runtime/pom.xml | 62 +++ .../runtime/async/ContinuationActionExecutor.java | 69 +++ .../agents/runtime/async/ContinuationContext.java | 26 + .../runtime/context/JavaRunnerContextImpl.java | 105 +++++ .../agents/runtime/context/RunnerContextImpl.java | 144 ++++++ .../runtime/operator/ActionExecutionOperator.java | 41 +- .../agents/runtime/operator/JavaActionTask.java | 45 +- .../runtime/async/ContinuationActionExecutor.java | 162 +++++++ .../agents/runtime/async/ContinuationContext.java | 62 +++ .../flink/agents/runtime/memory/MemoryRefTest.java | 11 + .../operator/ActionExecutionOperatorTest.java | 523 +++++++++++++++++++++ 17 files changed, 2067 insertions(+), 16 deletions(-) diff --git a/api/src/main/java/org/apache/flink/agents/api/context/DurableCallable.java b/api/src/main/java/org/apache/flink/agents/api/context/DurableCallable.java new file mode 100644 index 00000000..fb77a747 --- /dev/null +++ b/api/src/main/java/org/apache/flink/agents/api/context/DurableCallable.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.api.context; + +/** + * A callable interface for durable execution that requires a stable identifier. + * + * <p>This interface is used with {@link RunnerContext#durableExecute} and {@link + * RunnerContext#durableExecuteAsync} to ensure that each durable call has a stable, unique + * identifier that persists across job restarts. + * + * @param <T> the type of the result + */ +public interface DurableCallable<T> { + + /** + * Returns a stable identifier for this durable call. + * + * <p>This identifier must be unique within the action and deterministic for the same logical + * operation. The ID is used to match cached results during recovery. + */ + String getId(); + + /** Returns the class of the result for deserialization during recovery. */ + Class<T> getResultClass(); + + /** + * Executes the durable operation and returns the result. + * + * <p>This method will be called only if there is no cached result for this call. The result + * must be JSON-serializable. + */ + T call() throws Exception; +} diff --git a/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java b/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java index c7481690..0810752a 100644 --- a/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java +++ b/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java @@ -109,6 +109,29 @@ public interface RunnerContext { */ Object getActionConfigValue(String key); + /** + * Synchronously executes the provided callable with durable execution support. + * + * <p>The result will be stored and returned from cache during job recovery. The callable is + * executed synchronously, blocking the operator until completion. + * + * <p>Access to memory and sendEvent are prohibited within the callable. + */ + <T> T durableExecute(DurableCallable<T> callable) throws Exception; + + /** + * Asynchronously executes the provided callable with durable execution support. + * + * <p>On JDK 21+, this method uses Continuation to yield the current action execution, submits + * the callable to a thread pool, and resumes when complete. On JDK < 21, this falls back to + * synchronous execution. + * + * <p>The result will be stored and returned from cache during job recovery. + * + * <p>Access to memory and sendEvent are prohibited within the callable. + */ + <T> T durableExecuteAsync(DurableCallable<T> callable) throws Exception; + /** Clean up the resource. */ void close() throws Exception; } diff --git a/dist/pom.xml b/dist/pom.xml index f730e1d6..28bc748f 100644 --- a/dist/pom.xml +++ b/dist/pom.xml @@ -102,6 +102,13 @@ under the License. </excludes> </filter> </filters> + <transformers> + <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer"> + <manifestEntries> + <Multi-Release>true</Multi-Release> + </manifestEntries> + </transformer> + </transformers> </configuration> </execution> </executions> diff --git a/e2e-test/flink-agents-end-to-end-tests-integration/pom.xml b/e2e-test/flink-agents-end-to-end-tests-integration/pom.xml index 27e7d7ab..bf214ed7 100644 --- a/e2e-test/flink-agents-end-to-end-tests-integration/pom.xml +++ b/e2e-test/flink-agents-end-to-end-tests-integration/pom.xml @@ -32,7 +32,9 @@ under the License. <flink.1.20.version>1.20.3</flink.1.20.version> <flink.2.0.version>2.0.1</flink.2.0.version> <flink.2.1.version>2.1.1</flink.2.1.version> - <flink.2.2.version>2.2.0</flink.2.2.version> + + <flink.version>2.2.0</flink.version> + <flink.agents.dist.artifactId>flink-agents-dist-flink-2.2</flink.agents.dist.artifactId> </properties> <dependencies> @@ -109,16 +111,24 @@ under the License. </dependencies> <profiles> - <!-- Flink 2.2 Profile(Default) --> + <!-- JDK 21+ Profile for Continuation API --> <profile> - <id>flink-2.2</id> + <id>java-21</id> <activation> - <activeByDefault>true</activeByDefault> + <jdk>[21,)</jdk> </activation> - <properties> - <flink.version>${flink.2.2.version}</flink.version> - <flink.agents.dist.artifactId>flink-agents-dist-flink-2.2</flink.agents.dist.artifactId> - </properties> + <build> + <plugins> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-surefire-plugin</artifactId> + <configuration> + <!-- Add JVM args for JDK 21+ Continuation API access --> + <argLine>--add-exports java.base/jdk.internal.vm=ALL-UNNAMED</argLine> + </configuration> + </plugin> + </plugins> + </build> </profile> <!-- Flink 1.20 Profile --> diff --git a/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/AsyncExecutionAgent.java b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/AsyncExecutionAgent.java new file mode 100644 index 00000000..5142a385 --- /dev/null +++ b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/AsyncExecutionAgent.java @@ -0,0 +1,338 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.integration.test; + +import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.api.InputEvent; +import org.apache.flink.agents.api.OutputEvent; +import org.apache.flink.agents.api.agents.Agent; +import org.apache.flink.agents.api.annotation.Action; +import org.apache.flink.agents.api.context.DurableCallable; +import org.apache.flink.agents.api.context.MemoryObject; +import org.apache.flink.agents.api.context.RunnerContext; +import org.apache.flink.api.java.functions.KeySelector; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Agent definition for testing async execution functionality. + * + * <p>This agent demonstrates the usage of {@code durableExecuteAsync} for performing long-running + * operations without blocking the mailbox thread. + */ +public class AsyncExecutionAgent { + + private static final Logger LOG = LoggerFactory.getLogger(AsyncExecutionAgent.class); + + /** Simple request data class. */ + public static class AsyncRequest { + public final int id; + public final String data; + public final int sleepTimeMs; + + public AsyncRequest(int id, String data) { + this(id, data, 100); // Default sleep time + } + + public AsyncRequest(int id, String data, int sleepTimeMs) { + this.id = id; + this.data = data; + this.sleepTimeMs = sleepTimeMs; + } + + @Override + public String toString() { + return String.format( + "AsyncRequest{id=%d, data='%s', sleepTimeMs=%d}", id, data, sleepTimeMs); + } + } + + /** Key selector for extracting keys from AsyncRequest. */ + public static class AsyncRequestKeySelector implements KeySelector<AsyncRequest, Integer> { + @Override + public Integer getKey(AsyncRequest request) { + return request.id; + } + } + + /** Custom event type for internal agent communication. */ + public static class AsyncProcessedEvent extends Event { + private final String processedResult; + + public AsyncProcessedEvent(String processedResult) { + this.processedResult = processedResult; + } + + public String getProcessedResult() { + return processedResult; + } + } + + /** + * Agent that uses durableExecuteAsync for simulating slow operations. + * + * <p>On JDK 21+, this uses Continuation API for true async execution. On JDK < 21, this + * falls back to synchronous execution. + */ + public static class SimpleAsyncAgent extends Agent { + + @Action(listenEvents = {InputEvent.class}) + public static void processInput(Event event, RunnerContext ctx) throws Exception { + InputEvent inputEvent = (InputEvent) event; + AsyncRequest request = (AsyncRequest) inputEvent.getInput(); + + String result = + ctx.durableExecuteAsync( + new DurableCallable<String>() { + @Override + public String getId() { + return "simple-async-process"; + } + + @Override + public Class<String> getResultClass() { + return String.class; + } + + @Override + public String call() { + try { + Thread.sleep(100); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + return "Processed: " + request.data.toUpperCase(); + } + }); + + MemoryObject stm = ctx.getShortTermMemory(); + stm.set("lastResult", result); + + ctx.sendEvent(new AsyncProcessedEvent(result)); + } + + /** + * Action that handles processed events and generates output. + * + * @param event The processed event + * @param ctx The runner context for sending events + */ + @Action(listenEvents = {AsyncProcessedEvent.class}) + public static void generateOutput(Event event, RunnerContext ctx) throws Exception { + AsyncProcessedEvent processedEvent = (AsyncProcessedEvent) event; + + MemoryObject stm = ctx.getShortTermMemory(); + String lastResult = (String) stm.get("lastResult").getValue(); + + String output = + String.format( + "AsyncResult: %s | MemoryCheck: %s", + processedEvent.getProcessedResult(), lastResult); + ctx.sendEvent(new OutputEvent(output)); + } + } + + /** Agent that chains multiple durableExecuteAsync calls. */ + public static class MultiAsyncAgent extends Agent { + + @Action(listenEvents = {InputEvent.class}) + public static void processWithMultipleAsync(Event event, RunnerContext ctx) + throws Exception { + InputEvent inputEvent = (InputEvent) event; + AsyncRequest request = (AsyncRequest) inputEvent.getInput(); + + String step1Result = + ctx.durableExecuteAsync( + new DurableCallable<String>() { + @Override + public String getId() { + return "multi-async-step1"; + } + + @Override + public Class<String> getResultClass() { + return String.class; + } + + @Override + public String call() { + try { + Thread.sleep(100); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + return "Step1:" + request.data; + } + }); + + String step2Result = + ctx.durableExecuteAsync( + new DurableCallable<String>() { + @Override + public String getId() { + return "multi-async-step2"; + } + + @Override + public Class<String> getResultClass() { + return String.class; + } + + @Override + public String call() { + try { + Thread.sleep(100); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + return step1Result + "|Step2:processed"; + } + }); + + String finalResult = + ctx.durableExecuteAsync( + new DurableCallable<String>() { + @Override + public String getId() { + return "multi-async-step3"; + } + + @Override + public Class<String> getResultClass() { + return String.class; + } + + @Override + public String call() { + try { + Thread.sleep(100); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + return step2Result + "|Step3:done"; + } + }); + + MemoryObject stm = ctx.getShortTermMemory(); + stm.set("chainedResult", finalResult); + + ctx.sendEvent(new OutputEvent("MultiAsync[" + finalResult + "]")); + } + } + + /** Agent that uses durableExecuteAsync with configurable sleep time. */ + public static class TimedAsyncAgent extends Agent { + + private final int sleepTimeMs; + private final String timestampDir; + + public TimedAsyncAgent(int sleepTimeMs) { + this(sleepTimeMs, null); + } + + public TimedAsyncAgent(int sleepTimeMs, String timestampDir) { + this.sleepTimeMs = sleepTimeMs; + this.timestampDir = timestampDir; + } + + public int getSleepTimeMs() { + return sleepTimeMs; + } + + public String getTimestampDir() { + return timestampDir; + } + + @Action(listenEvents = {InputEvent.class}) + public static void processWithTiming(Event event, RunnerContext ctx) throws Exception { + InputEvent inputEvent = (InputEvent) event; + AsyncRequest request = (AsyncRequest) inputEvent.getInput(); + + String result = + ctx.durableExecuteAsync( + new DurableCallable<String>() { + @Override + public String getId() { + return "timed-async-" + request.id; + } + + @Override + public Class<String> getResultClass() { + return String.class; + } + + @Override + public String call() { + long asyncStartTime = System.currentTimeMillis(); + LOG.info("{} Async call start {}", request.id, asyncStartTime); + try { + Thread.sleep(request.sleepTimeMs); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + long asyncEndTime = System.currentTimeMillis(); + LOG.info("{} Async call end {}", request.id, asyncEndTime); + return String.format( + "key=%d,start=%d,end=%d", + request.id, asyncStartTime, asyncEndTime); + } + }); + + ctx.sendEvent(new OutputEvent("TimedAsync[" + result + "]")); + } + } + + /** Agent that uses durableExecute (sync) for simulating slow operations. */ + public static class SyncDurableAgent extends Agent { + + @Action(listenEvents = {InputEvent.class}) + public static void processInputSync(Event event, RunnerContext ctx) throws Exception { + InputEvent inputEvent = (InputEvent) event; + AsyncRequest request = (AsyncRequest) inputEvent.getInput(); + + String result = + ctx.durableExecute( + new DurableCallable<String>() { + @Override + public String getId() { + return "sync-durable-process"; + } + + @Override + public Class<String> getResultClass() { + return String.class; + } + + @Override + public String call() { + try { + Thread.sleep(50); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + return "SyncProcessed: " + request.data.toUpperCase(); + } + }); + + MemoryObject stm = ctx.getShortTermMemory(); + stm.set("syncResult", result); + + ctx.sendEvent(new OutputEvent("SyncDurable[" + result + "]")); + } + } +} diff --git a/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/AsyncExecutionTest.java b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/AsyncExecutionTest.java new file mode 100644 index 00000000..271b067f --- /dev/null +++ b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/AsyncExecutionTest.java @@ -0,0 +1,390 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.integration.test; + +import org.apache.flink.agents.api.AgentsExecutionEnvironment; +import org.apache.flink.agents.runtime.async.ContinuationActionExecutor; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.util.CloseableIterator; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.List; + +/** + * End-to-end tests for Java async execution functionality. + * + * <p>These tests verify that {@code durableExecuteAsync} works correctly for Java actions. + */ +public class AsyncExecutionTest { + + /** + * Tests that a simple async action works correctly. + * + * <p>The agent uses durableExecuteAsync to simulate a slow operation, then accesses memory and + * sends an event. + */ + @Test + public void testSimpleAsyncExecution() throws Exception { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(1); + + // Create input DataStream + DataStream<AsyncExecutionAgent.AsyncRequest> inputStream = + env.fromElements( + new AsyncExecutionAgent.AsyncRequest(1, "hello"), + new AsyncExecutionAgent.AsyncRequest(2, "world"), + new AsyncExecutionAgent.AsyncRequest(1, "flink")); + + // Create agents execution environment + AgentsExecutionEnvironment agentsEnv = + AgentsExecutionEnvironment.getExecutionEnvironment(env); + + // Apply agent to the DataStream + DataStream<Object> outputStream = + agentsEnv + .fromDataStream( + inputStream, new AsyncExecutionAgent.AsyncRequestKeySelector()) + .apply(new AsyncExecutionAgent.SimpleAsyncAgent()) + .toDataStream(); + + // Collect the results + CloseableIterator<Object> results = outputStream.collectAsync(); + + // Execute the pipeline + agentsEnv.execute(); + + // Verify results + List<String> outputList = new ArrayList<>(); + while (results.hasNext()) { + outputList.add(results.next().toString()); + } + results.close(); + + // Should have 3 outputs + Assertions.assertEquals(3, outputList.size()); + + // Each output should contain the async processed result + for (String output : outputList) { + Assertions.assertTrue( + output.contains("AsyncResult:"), + "Output should contain async result: " + output); + Assertions.assertTrue( + output.contains("Processed:"), + "Output should contain processed data: " + output); + Assertions.assertTrue( + output.contains("MemoryCheck:"), + "Output should contain memory check: " + output); + } + } + + /** + * Tests that multiple executeAsync calls can be chained within a single action. + * + * <p>The agent performs three sequential async operations and combines their results. + */ + @Test + public void testMultipleAsyncCalls() throws Exception { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(1); + + // Create input DataStream + DataStream<AsyncExecutionAgent.AsyncRequest> inputStream = + env.fromElements( + new AsyncExecutionAgent.AsyncRequest(1, "test1"), + new AsyncExecutionAgent.AsyncRequest(2, "test2")); + + // Create agents execution environment + AgentsExecutionEnvironment agentsEnv = + AgentsExecutionEnvironment.getExecutionEnvironment(env); + + // Apply agent to the DataStream + DataStream<Object> outputStream = + agentsEnv + .fromDataStream( + inputStream, new AsyncExecutionAgent.AsyncRequestKeySelector()) + .apply(new AsyncExecutionAgent.MultiAsyncAgent()) + .toDataStream(); + + // Collect the results + CloseableIterator<Object> results = outputStream.collectAsync(); + + // Execute the pipeline + agentsEnv.execute(); + + // Verify results + List<String> outputList = new ArrayList<>(); + while (results.hasNext()) { + outputList.add(results.next().toString()); + } + results.close(); + + // Should have 2 outputs + Assertions.assertEquals(2, outputList.size()); + + // Each output should contain all three steps + for (String output : outputList) { + Assertions.assertTrue( + output.contains("Step1:"), "Output should contain Step1: " + output); + Assertions.assertTrue( + output.contains("Step2:"), "Output should contain Step2: " + output); + Assertions.assertTrue( + output.contains("Step3:"), "Output should contain Step3: " + output); + } + } + + /** + * Tests that async execution works correctly with multiple keys processed concurrently. + * + * <p>Different keys should be processed independently with their own async operations. + */ + @Test + public void testAsyncWithMultipleKeysHighLoad() throws Exception { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(2); // Use parallelism to test concurrent processing + + // Create input DataStream with multiple elements across different keys + List<AsyncExecutionAgent.AsyncRequest> requests = new ArrayList<>(); + for (int key = 0; key < 5; key++) { + for (int i = 0; i < 3; i++) { + requests.add(new AsyncExecutionAgent.AsyncRequest(key, "data-" + key + "-" + i)); + } + } + + DataStream<AsyncExecutionAgent.AsyncRequest> inputStream = env.fromCollection(requests); + + // Create agents execution environment + AgentsExecutionEnvironment agentsEnv = + AgentsExecutionEnvironment.getExecutionEnvironment(env); + + // Apply agent to the DataStream + DataStream<Object> outputStream = + agentsEnv + .fromDataStream( + inputStream, new AsyncExecutionAgent.AsyncRequestKeySelector()) + .apply(new AsyncExecutionAgent.SimpleAsyncAgent()) + .toDataStream(); + + // Collect the results + CloseableIterator<Object> results = outputStream.collectAsync(); + + // Execute the pipeline + agentsEnv.execute(); + + // Verify results + List<String> outputList = new ArrayList<>(); + while (results.hasNext()) { + outputList.add(results.next().toString()); + } + results.close(); + + // Should have 15 outputs (5 keys * 3 elements each) + Assertions.assertEquals(15, outputList.size()); + + // All outputs should be valid + for (String output : outputList) { + Assertions.assertTrue( + output.contains("AsyncResult:"), + "Output should contain async result: " + output); + } + } + + /** + * Tests that async execution on JDK 21+ actually executes tasks in parallel. + * + * <p>This test creates multiple tasks that each sleep for a fixed duration. Each async task + * records its start and end timestamps. We verify parallel execution by checking if the + * execution time ranges overlap. + * + * <p>On JDK 21+: Tasks run in parallel, their execution times overlap On JDK < 21: Tasks run + * sequentially, no overlap + */ + @Test + public void testAsyncExecutionIsActuallyParallel() throws Exception { + boolean continuationSupported = ContinuationActionExecutor.isContinuationSupported(); + int javaVersion = Runtime.version().feature(); + + System.out.println("=== Async Parallelism Test ==="); + System.out.println("Java version: " + javaVersion); + System.out.println("Continuation supported: " + continuationSupported); + + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(1); // Single parallelism to test async within one operator + + // Create 3 requests with different keys, each will sleep 500ms + int numRequests = 3; + int sleepTimeMs = 500; + + List<AsyncExecutionAgent.AsyncRequest> requests = new ArrayList<>(); + for (int i = 0; i < numRequests; i++) { + requests.add( + new AsyncExecutionAgent.AsyncRequest(i, "parallel-test-" + i, sleepTimeMs)); + } + + DataStream<AsyncExecutionAgent.AsyncRequest> inputStream = env.fromCollection(requests); + + AgentsExecutionEnvironment agentsEnv = + AgentsExecutionEnvironment.getExecutionEnvironment(env); + + // Use TimedAsyncAgent which records timestamps + DataStream<Object> outputStream = + agentsEnv + .fromDataStream( + inputStream, new AsyncExecutionAgent.AsyncRequestKeySelector()) + .apply(new AsyncExecutionAgent.TimedAsyncAgent(sleepTimeMs)) + .toDataStream(); + + CloseableIterator<Object> results = outputStream.collectAsync(); + agentsEnv.execute(); + + // Parse execution timestamps from output + List<long[]> executionRanges = new ArrayList<>(); + while (results.hasNext()) { + String output = results.next().toString(); + // Parse: TimedAsync[key=X,start=Y,end=Z] + java.util.regex.Pattern pattern = + java.util.regex.Pattern.compile("start=(\\d+),end=(\\d+)"); + java.util.regex.Matcher matcher = pattern.matcher(output); + if (matcher.find()) { + long start = Long.parseLong(matcher.group(1)); + long end = Long.parseLong(matcher.group(2)); + executionRanges.add(new long[] {start, end}); + System.out.println("Task execution: start=" + start + ", end=" + end); + } + } + results.close(); + + Assertions.assertEquals(numRequests, executionRanges.size()); + + // Check for overlap between execution ranges + // Two ranges [s1, e1] and [s2, e2] overlap if s1 < e2 && s2 < e1 + int overlapCount = 0; + for (int i = 0; i < executionRanges.size(); i++) { + for (int j = i + 1; j < executionRanges.size(); j++) { + long[] range1 = executionRanges.get(i); + long[] range2 = executionRanges.get(j); + boolean overlaps = range1[0] < range2[1] && range2[0] < range1[1]; + if (overlaps) { + overlapCount++; + System.out.println( + "Overlap detected: [" + + range1[0] + + "," + + range1[1] + + "] and [" + + range2[0] + + "," + + range2[1] + + "]"); + } + } + } + + System.out.println("Total overlapping pairs: " + overlapCount); + + String classLocation = + ContinuationActionExecutor.class + .getProtectionDomain() + .getCodeSource() + .getLocation() + .toString(); + System.out.println("Class loaded from: " + classLocation); + + if (continuationSupported && javaVersion >= 21) { + // On JDK 21+, all tasks should overlap (parallel execution) + // With 3 tasks, we expect 3 overlapping pairs: (0,1), (0,2), (1,2) + int expectedOverlaps = (numRequests * (numRequests - 1)) / 2; + Assertions.assertTrue( + overlapCount >= expectedOverlaps - 1, // Allow some tolerance + String.format( + "On JDK 21+, async tasks should run in parallel (overlapping). " + + "Expected at least %d overlapping pairs, but found %d.", + expectedOverlaps - 1, overlapCount)); + System.out.println("✓ Async execution is PARALLEL (as expected on JDK 21+)"); + } else { + // On JDK < 21, tasks run sequentially - no overlap expected + Assertions.assertEquals( + 0, + overlapCount, + String.format( + "On JDK < 21, async tasks should run sequentially (no overlap). " + + "But found %d overlapping pairs.", + overlapCount)); + System.out.println("✓ Async execution is SEQUENTIAL (as expected on JDK < 21)"); + } + + System.out.println("=== Test Passed ==="); + } + + /** + * Tests that durableExecute (sync) works correctly. + * + * <p>The agent uses durableExecute to simulate a slow synchronous operation. + */ + @Test + public void testDurableExecuteSync() throws Exception { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(1); + + // Create input DataStream + DataStream<AsyncExecutionAgent.AsyncRequest> inputStream = + env.fromElements( + new AsyncExecutionAgent.AsyncRequest(1, "hello"), + new AsyncExecutionAgent.AsyncRequest(2, "world")); + + // Create agents execution environment + AgentsExecutionEnvironment agentsEnv = + AgentsExecutionEnvironment.getExecutionEnvironment(env); + + // Apply agent to the DataStream + DataStream<Object> outputStream = + agentsEnv + .fromDataStream( + inputStream, new AsyncExecutionAgent.AsyncRequestKeySelector()) + .apply(new AsyncExecutionAgent.SyncDurableAgent()) + .toDataStream(); + + // Collect the results + CloseableIterator<Object> results = outputStream.collectAsync(); + + // Execute the pipeline + agentsEnv.execute(); + + // Verify results + List<String> outputList = new ArrayList<>(); + while (results.hasNext()) { + outputList.add(results.next().toString()); + } + results.close(); + + // Should have 2 outputs + Assertions.assertEquals(2, outputList.size()); + + // Each output should contain the sync processed result + for (String output : outputList) { + Assertions.assertTrue( + output.contains("SyncDurable["), + "Output should contain sync durable result: " + output); + Assertions.assertTrue( + output.contains("SyncProcessed:"), + "Output should contain processed data: " + output); + } + } +} diff --git a/runtime/pom.xml b/runtime/pom.xml index dfa5be86..e2044828 100644 --- a/runtime/pom.xml +++ b/runtime/pom.xml @@ -142,4 +142,66 @@ under the License. <scope>provided</scope> </dependency> </dependencies> + + <build> + <plugins> + <!-- Configure jar plugin for Multi-Release JAR --> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-jar-plugin</artifactId> + <version>3.3.0</version> + <configuration> + <archive> + <manifestEntries> + <Multi-Release>true</Multi-Release> + </manifestEntries> + </archive> + </configuration> + </plugin> + </plugins> + </build> + + <profiles> + <!-- Profile for building JDK 21 specific classes when running on JDK 21+ --> + <profile> + <id>java-21</id> + <activation> + <jdk>[21,)</jdk> + </activation> + <build> + <plugins> + <!-- Compile JDK 21 specific sources --> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-compiler-plugin</artifactId> + <version>3.11.0</version> + <executions> + <execution> + <id>compile-java21</id> + <phase>compile</phase> + <goals> + <goal>compile</goal> + </goals> + <configuration> + <!-- Must not use release with add-exports, use source/target instead --> + <release combine.self="override"/> + <source>21</source> + <target>21</target> + <fork>true</fork> + <compileSourceRoots> + <compileSourceRoot>${project.basedir}/src/main/java21</compileSourceRoot> + </compileSourceRoots> + <outputDirectory>${project.build.outputDirectory}/META-INF/versions/21</outputDirectory> + <compilerArgs> + <arg>--add-exports</arg> + <arg>java.base/jdk.internal.vm=ALL-UNNAMED</arg> + </compilerArgs> + </configuration> + </execution> + </executions> + </plugin> + </plugins> + </build> + </profile> + </profiles> </project> \ No newline at end of file diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/async/ContinuationActionExecutor.java b/runtime/src/main/java/org/apache/flink/agents/runtime/async/ContinuationActionExecutor.java new file mode 100644 index 00000000..c0b4cee8 --- /dev/null +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/async/ContinuationActionExecutor.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.async; + +import java.util.function.Supplier; + +/** + * Executor for Java actions that supports asynchronous execution. + * + * <p>This is the JDK 11 version that falls back to synchronous execution. On JDK 21+, the + * Multi-release JAR will use a version that leverages Continuation API for true async execution. + */ +public class ContinuationActionExecutor { + + /** Creates a new ContinuationActionExecutor. */ + public ContinuationActionExecutor() {} + + /** + * Executes the action. In JDK 11, this simply runs the action synchronously. + * + * @param context the continuation context + * @param action the action to execute + * @return true if the action completed, false if it yielded (always true in JDK 11) + */ + public boolean executeAction(ContinuationContext context, Runnable action) { + action.run(); + return true; + } + + /** + * Asynchronously executes the provided supplier. In JDK 11, this falls back to synchronous + * execution. + * + * @param context the continuation context + * @param supplier the supplier to execute + * @param <T> the result type + * @return the result of the supplier + */ + public <T> T executeAsync(ContinuationContext context, Supplier<T> supplier) { + // JDK 11: Fall back to synchronous execution + return supplier.get(); + } + + public void close() {} + + /** + * Returns whether continuation-based async execution is supported. + * + * @return true if Continuation API is available (JDK 21+), false otherwise + */ + public static boolean isContinuationSupported() { + return false; + } +} diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/async/ContinuationContext.java b/runtime/src/main/java/org/apache/flink/agents/runtime/async/ContinuationContext.java new file mode 100644 index 00000000..6d0d7e09 --- /dev/null +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/async/ContinuationContext.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.async; + +/** + * Marker context for continuation execution. + * + * <p>This class is part of the base (JDK 11) sources. JDK 21 provides a multi-release variant with + * continuation-specific fields. + */ +public class ContinuationContext {} diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/context/JavaRunnerContextImpl.java b/runtime/src/main/java/org/apache/flink/agents/runtime/context/JavaRunnerContextImpl.java new file mode 100644 index 00000000..a18bb30b --- /dev/null +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/context/JavaRunnerContextImpl.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.context; + +import org.apache.flink.agents.api.context.DurableCallable; +import org.apache.flink.agents.plan.AgentPlan; +import org.apache.flink.agents.runtime.async.ContinuationActionExecutor; +import org.apache.flink.agents.runtime.async.ContinuationContext; +import org.apache.flink.agents.runtime.metrics.FlinkAgentsMetricGroupImpl; + +import java.util.function.Supplier; + +/** + * Java-specific implementation of RunnerContext that includes ContinuationActionExecutor for async + * execution support. + */ +public class JavaRunnerContextImpl extends RunnerContextImpl { + + private final ContinuationActionExecutor continuationExecutor; + private ContinuationContext continuationContext; + + public JavaRunnerContextImpl( + FlinkAgentsMetricGroupImpl agentMetricGroup, + Runnable mailboxThreadChecker, + AgentPlan agentPlan, + String jobIdentifier, + ContinuationActionExecutor continuationExecutor) { + super(agentMetricGroup, mailboxThreadChecker, agentPlan, jobIdentifier); + this.continuationExecutor = continuationExecutor; + } + + public ContinuationActionExecutor getContinuationExecutor() { + return continuationExecutor; + } + + public void setContinuationContext(ContinuationContext continuationContext) { + this.continuationContext = continuationContext; + } + + public ContinuationContext getContinuationContext() { + return continuationContext; + } + + @Override + public <T> T durableExecuteAsync(DurableCallable<T> callable) throws Exception { + String functionId = callable.getId(); + String argsDigest = ""; + + java.util.Optional<T> cachedResult = + tryGetCachedResult(functionId, argsDigest, callable.getResultClass()); + if (cachedResult.isPresent()) { + return cachedResult.get(); + } + + Supplier<T> wrappedSupplier = + () -> { + T innerResult = null; + Exception innerException = null; + try { + innerResult = callable.call(); + } catch (Exception e) { + innerException = e; + } + + if (innerException != null) { + throw new DurableExecutionRuntimeException(innerException); + } + return innerResult; + }; + + T result = null; + Exception originalException = null; + try { + if (continuationExecutor == null || continuationContext == null) { + result = wrappedSupplier.get(); + } else { + result = continuationExecutor.executeAsync(continuationContext, wrappedSupplier); + } + } catch (DurableExecutionRuntimeException e) { + originalException = (Exception) e.getCause(); + } + + recordDurableCompletion(functionId, argsDigest, result, originalException); + + if (originalException != null) { + throw originalException; + } + return result; + } +} diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java b/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java index 8a946f2d..01b6e4c4 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java @@ -18,8 +18,10 @@ package org.apache.flink.agents.runtime.context; import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; import org.apache.flink.agents.api.Event; import org.apache.flink.agents.api.configuration.ReadableConfiguration; +import org.apache.flink.agents.api.context.DurableCallable; import org.apache.flink.agents.api.context.MemoryObject; import org.apache.flink.agents.api.context.MemoryUpdate; import org.apache.flink.agents.api.context.RunnerContext; @@ -47,12 +49,17 @@ import java.util.ArrayList; import java.util.LinkedList; import java.util.List; import java.util.Map; +import java.util.Optional; +import java.util.function.Supplier; /** * The implementation class of {@link RunnerContext}, which serves as the execution context for * actions. */ public class RunnerContextImpl implements RunnerContext { + + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + public static class MemoryContext { private final CachedMemoryStore sensoryMemStore; private final CachedMemoryStore shortTermMemStore; @@ -237,6 +244,143 @@ public class RunnerContextImpl implements RunnerContext { return agentPlan.getActionConfigValue(actionName, key); } + protected <T> Optional<T> tryGetCachedResult( + String functionId, String argsDigest, Class<T> resultClass) throws Exception { + Object[] cached = matchNextOrClearSubsequentCallResult(functionId, argsDigest); + if (cached != null && (Boolean) cached[0]) { + byte[] resultPayload = (byte[]) cached[1]; + byte[] exceptionPayload = (byte[]) cached[2]; + + if (exceptionPayload != null) { + DurableExecutionException cachedException = + OBJECT_MAPPER.readValue(exceptionPayload, DurableExecutionException.class); + throw cachedException.toException(); + } else if (resultPayload != null) { + return Optional.of(OBJECT_MAPPER.readValue(resultPayload, resultClass)); + } else { + return Optional.of(null); + } + } + return Optional.empty(); + } + + protected void recordDurableCompletion( + String functionId, String argsDigest, Object result, Exception exception) + throws Exception { + byte[] resultPayload = null; + byte[] exceptionPayload = null; + if (exception != null) { + exceptionPayload = + OBJECT_MAPPER.writeValueAsBytes( + DurableExecutionException.fromException(exception)); + } else if (result != null) { + resultPayload = OBJECT_MAPPER.writeValueAsBytes(result); + } + recordCallCompletion(functionId, argsDigest, resultPayload, exceptionPayload); + } + + @Override + public <T> T durableExecute(DurableCallable<T> callable) throws Exception { + String functionId = callable.getId(); + // argsDigest is empty because DurableCallable encapsulates all arguments internally + String argsDigest = ""; + + Optional<T> cachedResult = + tryGetCachedResult(functionId, argsDigest, callable.getResultClass()); + if (cachedResult.isPresent()) { + return cachedResult.get(); + } + + T result = null; + Exception exception = null; + try { + result = callable.call(); + } catch (Exception e) { + exception = e; + } + + recordDurableCompletion(functionId, argsDigest, result, exception); + + if (exception != null) { + throw exception; + } + return result; + } + + @Override + public <T> T durableExecuteAsync(DurableCallable<T> callable) throws Exception { + String functionId = callable.getId(); + // argsDigest is empty because DurableCallable encapsulates all arguments internally + String argsDigest = ""; + + Optional<T> cachedResult = + tryGetCachedResult(functionId, argsDigest, callable.getResultClass()); + if (cachedResult.isPresent()) { + return cachedResult.get(); + } + + Supplier<T> wrappedSupplier = + () -> { + T innerResult = null; + Exception innerException = null; + try { + innerResult = callable.call(); + } catch (Exception e) { + innerException = e; + } + + if (innerException != null) { + throw new DurableExecutionRuntimeException(innerException); + } + return innerResult; + }; + + T result = null; + Exception originalException = null; + try { + result = wrappedSupplier.get(); + } catch (DurableExecutionRuntimeException e) { + originalException = (Exception) e.getCause(); + } + + recordDurableCompletion(functionId, argsDigest, result, originalException); + + if (originalException != null) { + throw originalException; + } + return result; + } + + protected static class DurableExecutionRuntimeException extends RuntimeException { + DurableExecutionRuntimeException(Throwable cause) { + super(cause); + } + } + + /** Serializable exception info for durable execution persistence. */ + public static class DurableExecutionException { + private final String exceptionClass; + private final String message; + + public DurableExecutionException() { + this.exceptionClass = null; + this.message = null; + } + + public DurableExecutionException(String exceptionClass, String message) { + this.exceptionClass = exceptionClass; + this.message = message; + } + + public static DurableExecutionException fromException(Exception e) { + return new DurableExecutionException(e.getClass().getName(), e.getMessage()); + } + + public Exception toException() { + return new RuntimeException(exceptionClass + ": " + message); + } + } + @Override public void close() throws Exception { if (this.ltm != null) { diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java index a11efc0e..6a6ae880 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java @@ -38,7 +38,10 @@ import org.apache.flink.agents.plan.resourceprovider.PythonResourceProvider; import org.apache.flink.agents.runtime.actionstate.ActionState; import org.apache.flink.agents.runtime.actionstate.ActionStateStore; import org.apache.flink.agents.runtime.actionstate.KafkaActionStateStore; +import org.apache.flink.agents.runtime.async.ContinuationActionExecutor; +import org.apache.flink.agents.runtime.async.ContinuationContext; import org.apache.flink.agents.runtime.context.ActionStatePersister; +import org.apache.flink.agents.runtime.context.JavaRunnerContextImpl; import org.apache.flink.agents.runtime.context.RunnerContextImpl; import org.apache.flink.agents.runtime.env.EmbeddedPythonEnvironment; import org.apache.flink.agents.runtime.env.PythonEnvironmentManager; @@ -196,6 +199,8 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT private final transient Map<ActionTask, RunnerContextImpl.DurableExecutionContext> actionTaskDurableContexts; + private final transient Map<ActionTask, ContinuationContext> continuationContexts; + // Each job can only have one identifier and this identifier must be consistent across restarts. // We cannot use job id as the identifier here because user may change job id by // creating a savepoint, stop the job and then resume from savepoint. @@ -203,6 +208,8 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT // Inspired by Apache Paimon. private transient String jobIdentifier; + private transient ContinuationActionExecutor continuationActionExecutor; + public ActionExecutionOperator( AgentPlan agentPlan, Boolean inputIsJava, @@ -219,6 +226,7 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT this.checkpointIdToSeqNums = new HashMap<>(); this.actionTaskMemoryContexts = new HashMap<>(); this.actionTaskDurableContexts = new HashMap<>(); + this.continuationContexts = new HashMap<>(); OperatorUtils.setChainStrategy(this, ChainingStrategy.ALWAYS); } @@ -304,6 +312,9 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT // init PythonActionExecutor and PythonResourceAdapter initPythonEnvironment(); + // init executor for Java async execution + continuationActionExecutor = new ContinuationActionExecutor(); + mailboxProcessor = getMailboxProcessor(); // Initialize the event logger if it is set. @@ -494,6 +505,7 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT // finished. actionTaskMemoryContexts.remove(actionTask); actionTaskDurableContexts.remove(actionTask); + continuationContexts.remove(actionTask); maybePersistTaskResult( key, sequenceNumber, @@ -535,6 +547,12 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT if (durableContext != null) { actionTaskDurableContexts.put(generatedActionTask, durableContext); } + if (actionTask.getRunnerContext() instanceof JavaRunnerContextImpl) { + continuationContexts.put( + generatedActionTask, + ((JavaRunnerContextImpl) actionTask.getRunnerContext()) + .getContinuationContext()); + } actionTasksKState.add(generatedActionTask); } @@ -694,6 +712,9 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT if (runnerContext != null) { runnerContext.close(); } + if (continuationActionExecutor != null) { + continuationActionExecutor.close(); + } super.close(); } @@ -849,6 +870,18 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT runnerContext.switchActionContext( actionTask.action.getName(), memoryContext, String.valueOf(key.hashCode())); + + if (runnerContext instanceof JavaRunnerContextImpl) { + ContinuationContext continuationContext; + if (continuationContexts.containsKey(actionTask)) { + // action task for async execution action, should retrieve intermediate results from + // map. + continuationContext = continuationContexts.get(actionTask); + } else { + continuationContext = new ContinuationContext(); + } + ((JavaRunnerContextImpl) runnerContext).setContinuationContext(continuationContext); + } actionTask.setRunnerContext(runnerContext); } @@ -1019,12 +1052,16 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT private RunnerContextImpl createOrGetRunnerContext(Boolean isJava) { if (isJava) { if (runnerContext == null) { + if (continuationActionExecutor == null) { + continuationActionExecutor = new ContinuationActionExecutor(); + } runnerContext = - new RunnerContextImpl( + new JavaRunnerContextImpl( this.metricGroup, this::checkMailboxThread, this.agentPlan, - this.jobIdentifier); + this.jobIdentifier, + continuationActionExecutor); } return runnerContext; } else { diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/JavaActionTask.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/JavaActionTask.java index 65d8ef4a..11724ce6 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/JavaActionTask.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/JavaActionTask.java @@ -20,17 +20,26 @@ package org.apache.flink.agents.runtime.operator; import org.apache.flink.agents.api.Event; import org.apache.flink.agents.plan.JavaFunction; import org.apache.flink.agents.plan.actions.Action; +import org.apache.flink.agents.runtime.context.JavaRunnerContextImpl; import org.apache.flink.agents.runtime.python.utils.PythonActionExecutor; +import java.util.Collections; + import static org.apache.flink.util.Preconditions.checkState; /** * A special {@link ActionTask} designed to execute a Java action task. * - * <p>Note that Java action currently do not support asynchronous execution. As a result, a Java - * action task will be invoked only once. + * <p>On JDK 21+, this task supports asynchronous execution via {@code executeAsync} in the action + * code. When the action yields for async execution, this task returns with {@code finished=false} + * and generates itself as the next task to continue execution. + * + * <p>On JDK < 21, async execution falls back to synchronous mode. */ public class JavaActionTask extends ActionTask { + + private boolean executionStarted = false; + public JavaActionTask(Object key, Event event, Action action) { super(key, event, action); checkState(action.getExec() instanceof JavaFunction); @@ -44,15 +53,39 @@ public class JavaActionTask extends ActionTask { action.getName(), event, key); - runnerContext.checkNoPendingEvents(); + + if (!executionStarted) { + runnerContext.checkNoPendingEvents(); + executionStarted = true; + } + + JavaRunnerContextImpl javaRunnerContext = (JavaRunnerContextImpl) runnerContext; + ClassLoader cl = Thread.currentThread().getContextClassLoader(); + boolean finished; try { Thread.currentThread().setContextClassLoader(userCodeClassLoader); - action.getExec().call(event, runnerContext); + finished = + javaRunnerContext + .getContinuationExecutor() + .executeAction( + javaRunnerContext.getContinuationContext(), + () -> { + try { + action.getExec().call(event, runnerContext); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); } finally { Thread.currentThread().setContextClassLoader(cl); } - return new ActionTaskResult( - true, runnerContext.drainEvents(event.getSourceTimestamp()), null); + + if (finished) { + return new ActionTaskResult( + true, runnerContext.drainEvents(event.getSourceTimestamp()), null); + } else { + return new ActionTaskResult(false, Collections.emptyList(), this); + } } } diff --git a/runtime/src/main/java21/org/apache/flink/agents/runtime/async/ContinuationActionExecutor.java b/runtime/src/main/java21/org/apache/flink/agents/runtime/async/ContinuationActionExecutor.java new file mode 100644 index 00000000..eda859d6 --- /dev/null +++ b/runtime/src/main/java21/org/apache/flink/agents/runtime/async/ContinuationActionExecutor.java @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.async; + +import jdk.internal.vm.Continuation; +import jdk.internal.vm.ContinuationScope; + +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.function.Supplier; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Executor for Java actions that supports asynchronous execution using JDK 21+ Continuation API. + * + * <p>This version uses {@code jdk.internal.vm.Continuation} to implement true async execution. + */ +public class ContinuationActionExecutor { + + private static final Logger LOG = LoggerFactory.getLogger(ContinuationActionExecutor.class); + + private static final ContinuationScope SCOPE = new ContinuationScope("FlinkAgentsAction"); + + private final ExecutorService asyncExecutor; + + public ContinuationActionExecutor() { + this.asyncExecutor = + Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors() * 2); + } + + /** + * Executes the action inside a Continuation. + * + * <p>If the action calls executeAsync and yields, this method checks if the async Future is + * done. If not done, returns false to indicate the action is not finished. If done, resumes the + * Continuation. + * + * @param context the continuation context for this action + * @param action the action to execute + * @return true if the action completed, false if waiting for async execution + */ + public boolean executeAction(ContinuationContext context, Runnable action) { + // Check if we have a pending async Future from previous yield + Future<?> pending = context.getPendingFuture(); + if (pending != null) { + if (!pending.isDone()) { + // Async task not done yet, return false to wait + return false; + } + // Async task done, clear the pending future and resume + LOG.debug("Async task done..."); + context.setPendingFuture(null); + } + + Continuation currentContinuation = context.getCurrentContinuation(); + if (currentContinuation == null) { + // First invocation: create new Continuation + LOG.debug("Create new continuation."); + currentContinuation = new Continuation(SCOPE, action); + context.setCurrentContinuation(currentContinuation); + } + + // Run the continuation. It returns either when the action completes or when it yields + // inside executeAsync; in the latter case we return false and let the next executeAction + // call observe pendingFuture completion and resume. + currentContinuation.run(); + + if (currentContinuation.isDone()) { + // Continuation completed + context.setCurrentContinuation(null); + LOG.debug("Current continuation is done."); + return true; + } else { + // Continuation yielded, waiting for async task + // pendingFuture should have been set by executeAsync + LOG.debug("Current continuation still running."); + return false; + } + } + + /** + * Asynchronously executes the provided supplier using Continuation. + * + * <p>This method submits the task to a thread pool and yields the Continuation. The next call + * to executeAction will check if the Future is done and resume accordingly. + * + * @param context the continuation context for this action + * @param supplier the supplier to execute + * @param <T> the result type + * @return the result of the supplier + * @throws Exception if the async execution fails + */ + @SuppressWarnings("unchecked") + public <T> T executeAsync(ContinuationContext context, Supplier<T> supplier) throws Exception { + // Clear previous state + context.clearAsyncState(); + + // Submit task to thread pool and store the Future + Future<?> future = + asyncExecutor.submit( + () -> { + try { + T result = supplier.get(); + context.getAsyncResultRef().set(result); + } catch (Throwable t) { + context.getAsyncExceptionRef().set(t); + } + }); + + // Store the future reference before yielding (volatile write ensures visibility) + context.setPendingFuture(future); + + // Yield until the future is done + while (!future.isDone()) { + Continuation.yield(SCOPE); + } + + // Check for exception from the async task + Throwable exception = context.getAsyncExceptionRef().get(); + if (exception != null) { + if (exception instanceof Exception) { + throw (Exception) exception; + } else if (exception instanceof Error) { + throw (Error) exception; + } else { + throw new RuntimeException(exception); + } + } + + return (T) context.getAsyncResultRef().get(); + } + + public void close() { + asyncExecutor.shutdownNow(); + } + + /** + * Returns whether continuation-based async execution is supported. + * + * @return true (this is the JDK 21+ version) + */ + public static boolean isContinuationSupported() { + return true; + } +} diff --git a/runtime/src/main/java21/org/apache/flink/agents/runtime/async/ContinuationContext.java b/runtime/src/main/java21/org/apache/flink/agents/runtime/async/ContinuationContext.java new file mode 100644 index 00000000..fbff2b11 --- /dev/null +++ b/runtime/src/main/java21/org/apache/flink/agents/runtime/async/ContinuationContext.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.async; + +import jdk.internal.vm.Continuation; + +import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicReference; + +/** Continuation context with JDK 21 continuation state. */ +public class ContinuationContext { + + private Continuation currentContinuation; + private volatile Future<?> pendingFuture; + private final AtomicReference<Object> asyncResult = new AtomicReference<>(); + private final AtomicReference<Throwable> asyncException = new AtomicReference<>(); + + public Continuation getCurrentContinuation() { + return currentContinuation; + } + + public void setCurrentContinuation(Continuation currentContinuation) { + this.currentContinuation = currentContinuation; + } + + public Future<?> getPendingFuture() { + return pendingFuture; + } + + public void setPendingFuture(Future<?> pendingFuture) { + this.pendingFuture = pendingFuture; + } + + public AtomicReference<Object> getAsyncResultRef() { + return asyncResult; + } + + public AtomicReference<Throwable> getAsyncExceptionRef() { + return asyncException; + } + + public void clearAsyncState() { + pendingFuture = null; + asyncResult.set(null); + asyncException.set(null); + } +} diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/memory/MemoryRefTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/memory/MemoryRefTest.java index 52a5ff33..e1636155 100644 --- a/runtime/src/test/java/org/apache/flink/agents/runtime/memory/MemoryRefTest.java +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/memory/MemoryRefTest.java @@ -18,6 +18,7 @@ package org.apache.flink.agents.runtime.memory; import org.apache.flink.agents.api.configuration.ReadableConfiguration; +import org.apache.flink.agents.api.context.DurableCallable; import org.apache.flink.agents.api.context.MemoryObject; import org.apache.flink.agents.api.context.MemoryRef; import org.apache.flink.agents.api.context.RunnerContext; @@ -117,6 +118,16 @@ public class MemoryRefTest { return null; } + @Override + public <T> T durableExecute(DurableCallable<T> callable) throws Exception { + return callable.call(); + } + + @Override + public <T> T durableExecuteAsync(DurableCallable<T> callable) throws Exception { + return callable.call(); + } + @Override public void close() throws Exception {} } diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java index 4c58f98e..f098acd8 100644 --- a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java @@ -20,6 +20,7 @@ package org.apache.flink.agents.runtime.operator; import org.apache.flink.agents.api.Event; import org.apache.flink.agents.api.InputEvent; import org.apache.flink.agents.api.OutputEvent; +import org.apache.flink.agents.api.context.DurableCallable; import org.apache.flink.agents.api.context.MemoryObject; import org.apache.flink.agents.api.context.RunnerContext; import org.apache.flink.agents.plan.AgentPlan; @@ -44,6 +45,7 @@ import java.util.Map; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; +import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -489,8 +491,270 @@ public class ActionExecutionOperatorTest { } } + /** Tests that executeAsync works correctly. */ + @Test + void testExecuteAsyncJavaAction() throws Exception { + try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> testHarness = + new KeyedOneInputStreamOperatorTestHarness<>( + new ActionExecutionOperatorFactory( + TestAgent.getAsyncAgentPlan(false), true), + (KeySelector<Long, Long>) value -> value, + TypeInformation.of(Long.class))) { + testHarness.open(); + ActionExecutionOperator<Long, Object> operator = + (ActionExecutionOperator<Long, Object>) testHarness.getOperator(); + + // Input value 5: asyncAction1 computes 5 * 10 = 50, action2 computes 50 * 2 = 100 + testHarness.processElement(new StreamRecord<>(5L)); + operator.waitInFlightEventsFinished(); + + List<StreamRecord<Object>> recordOutput = + (List<StreamRecord<Object>>) testHarness.getRecordOutput(); + assertThat(recordOutput.size()).isEqualTo(1); + assertThat(recordOutput.get(0).getValue()).isEqualTo(100L); + } + } + + /** + * Tests that multiple executeAsync calls can be chained together. Each async operation should + * complete before the next one starts (serial execution). + */ + @Test + void testMultipleExecuteAsyncCalls() throws Exception { + try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> testHarness = + new KeyedOneInputStreamOperatorTestHarness<>( + new ActionExecutionOperatorFactory(TestAgent.getAsyncAgentPlan(true), true), + (KeySelector<Long, Long>) value -> value, + TypeInformation.of(Long.class))) { + testHarness.open(); + ActionExecutionOperator<Long, Object> operator = + (ActionExecutionOperator<Long, Object>) testHarness.getOperator(); + + // Input value 7: + // First async: 7 + 100 = 107 + // Second async: 107 * 2 = 214 + testHarness.processElement(new StreamRecord<>(7L)); + operator.waitInFlightEventsFinished(); + + List<StreamRecord<Object>> recordOutput = + (List<StreamRecord<Object>>) testHarness.getRecordOutput(); + assertThat(recordOutput.size()).isEqualTo(1); + assertThat(recordOutput.get(0).getValue()).isEqualTo(214L); + } + } + + /** + * Tests that executeAsync works correctly with multiple keys processed concurrently. Each key + * should complete its async operations independently. + */ + @Test + void testExecuteAsyncWithMultipleKeys() throws Exception { + try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> testHarness = + new KeyedOneInputStreamOperatorTestHarness<>( + new ActionExecutionOperatorFactory( + TestAgent.getAsyncAgentPlan(false), true), + (KeySelector<Long, Long>) value -> value, + TypeInformation.of(Long.class))) { + testHarness.open(); + ActionExecutionOperator<Long, Object> operator = + (ActionExecutionOperator<Long, Object>) testHarness.getOperator(); + + // Process two elements with different keys + // Key 3: asyncAction1 computes 3 * 10 = 30, action2 computes 30 * 2 = 60 + // Key 4: asyncAction1 computes 4 * 10 = 40, action2 computes 40 * 2 = 80 + testHarness.processElement(new StreamRecord<>(3L)); + testHarness.processElement(new StreamRecord<>(4L)); + operator.waitInFlightEventsFinished(); + + List<StreamRecord<Object>> recordOutput = + (List<StreamRecord<Object>>) testHarness.getRecordOutput(); + assertThat(recordOutput.size()).isEqualTo(2); + + // Check both outputs exist (order may vary due to concurrent processing) + List<Object> outputValues = + recordOutput.stream().map(StreamRecord::getValue).collect(Collectors.toList()); + assertThat(outputValues).containsExactlyInAnyOrder(60L, 80L); + } + } + + /** Tests that durableExecute (sync) works correctly. */ + @Test + void testDurableExecuteSyncAction() throws Exception { + try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> testHarness = + new KeyedOneInputStreamOperatorTestHarness<>( + new ActionExecutionOperatorFactory( + TestAgent.getDurableSyncAgentPlan(), true), + (KeySelector<Long, Long>) value -> value, + TypeInformation.of(Long.class))) { + testHarness.open(); + ActionExecutionOperator<Long, Object> operator = + (ActionExecutionOperator<Long, Object>) testHarness.getOperator(); + + // Input value 5: durableSyncAction computes 5 * 3 = 15 + testHarness.processElement(new StreamRecord<>(5L)); + operator.waitInFlightEventsFinished(); + + List<StreamRecord<Object>> recordOutput = + (List<StreamRecord<Object>>) testHarness.getRecordOutput(); + assertThat(recordOutput.size()).isEqualTo(1); + assertThat(recordOutput.get(0).getValue()).isEqualTo(15L); + } + } + + /** + * Tests that durableExecute with ActionStateStore can recover from cached results. This + * verifies that on recovery, the durable execution returns cached results without re-executing + * the supplier. + */ + @Test + void testDurableExecuteRecoveryFromCachedResult() throws Exception { + AgentPlan agentPlan = TestAgent.getDurableSyncAgentPlan(); + InMemoryActionStateStore actionStateStore = new InMemoryActionStateStore(false); + + // Reset the counter before the test + TestAgent.DURABLE_CALL_COUNTER.set(0); + + // First execution - will execute the supplier and store the result + try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> testHarness = + new KeyedOneInputStreamOperatorTestHarness<>( + new ActionExecutionOperatorFactory<>(agentPlan, true, actionStateStore), + (KeySelector<Long, Long>) value -> value, + TypeInformation.of(Long.class))) { + testHarness.open(); + ActionExecutionOperator<Long, Object> operator = + (ActionExecutionOperator<Long, Object>) testHarness.getOperator(); + + testHarness.processElement(new StreamRecord<>(7L)); + operator.waitInFlightEventsFinished(); + + List<StreamRecord<Object>> recordOutput = + (List<StreamRecord<Object>>) testHarness.getRecordOutput(); + assertThat(recordOutput.size()).isEqualTo(1); + // 7 * 3 = 21 + assertThat(recordOutput.get(0).getValue()).isEqualTo(21L); + + // Verify action state was stored + assertThat(actionStateStore.getKeyedActionStates()).isNotEmpty(); + + // Verify supplier was called exactly once during first execution + assertThat(TestAgent.DURABLE_CALL_COUNTER.get()).isEqualTo(1); + } + + // Second execution with same action state store - should recover from cached result + try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> testHarness = + new KeyedOneInputStreamOperatorTestHarness<>( + new ActionExecutionOperatorFactory<>(agentPlan, true, actionStateStore), + (KeySelector<Long, Long>) value -> value, + TypeInformation.of(Long.class))) { + testHarness.open(); + ActionExecutionOperator<Long, Object> operator = + (ActionExecutionOperator<Long, Object>) testHarness.getOperator(); + + // Process the same key - should recover from cached state + testHarness.processElement(new StreamRecord<>(7L)); + operator.waitInFlightEventsFinished(); + + List<StreamRecord<Object>> recordOutput = + (List<StreamRecord<Object>>) testHarness.getRecordOutput(); + assertThat(recordOutput.size()).isEqualTo(1); + // Should get the same result (21) from recovery + assertThat(recordOutput.get(0).getValue()).isEqualTo(21L); + + // CRITICAL: Verify supplier was NOT called during recovery - counter should still be 1 + assertThat(TestAgent.DURABLE_CALL_COUNTER.get()) + .as("Supplier should NOT be called during recovery") + .isEqualTo(1); + } + } + + /** Tests that durableExecute properly handles exceptions thrown by the supplier. */ + @Test + void testDurableExecuteExceptionHandling() throws Exception { + try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> testHarness = + new KeyedOneInputStreamOperatorTestHarness<>( + new ActionExecutionOperatorFactory( + TestAgent.getDurableExceptionAgentPlan(), true), + (KeySelector<Long, Long>) value -> value, + TypeInformation.of(Long.class))) { + testHarness.open(); + ActionExecutionOperator<Long, Object> operator = + (ActionExecutionOperator<Long, Object>) testHarness.getOperator(); + + // Reset counter + TestAgent.EXCEPTION_CALL_COUNTER.set(0); + + testHarness.processElement(new StreamRecord<>(1L)); + operator.waitInFlightEventsFinished(); + + List<StreamRecord<Object>> recordOutput = + (List<StreamRecord<Object>>) testHarness.getRecordOutput(); + assertThat(recordOutput.size()).isEqualTo(1); + // Verify the error was caught and handled + assertThat(recordOutput.get(0).getValue().toString()).contains("ERROR:"); + + // Verify the supplier was called + assertThat(TestAgent.EXCEPTION_CALL_COUNTER.get()).isEqualTo(1); + } + } + + /** + * Tests that exception recovery works correctly - on recovery, the cached exception should be + * re-thrown without calling the supplier again. + */ + @Test + void testDurableExecuteExceptionRecovery() throws Exception { + AgentPlan agentPlan = TestAgent.getDurableExceptionAgentPlan(); + InMemoryActionStateStore actionStateStore = new InMemoryActionStateStore(false); + + // Reset counter + TestAgent.EXCEPTION_CALL_COUNTER.set(0); + + // First execution - will execute the supplier, throw exception, and store it + try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> testHarness = + new KeyedOneInputStreamOperatorTestHarness<>( + new ActionExecutionOperatorFactory<>(agentPlan, true, actionStateStore), + (KeySelector<Long, Long>) value -> value, + TypeInformation.of(Long.class))) { + testHarness.open(); + ActionExecutionOperator<Long, Object> operator = + (ActionExecutionOperator<Long, Object>) testHarness.getOperator(); + + testHarness.processElement(new StreamRecord<>(2L)); + operator.waitInFlightEventsFinished(); + + // Verify supplier was called once + assertThat(TestAgent.EXCEPTION_CALL_COUNTER.get()).isEqualTo(1); + + // Verify action state was stored + assertThat(actionStateStore.getKeyedActionStates()).isNotEmpty(); + } + + // Second execution - should recover cached exception without calling supplier + try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> testHarness = + new KeyedOneInputStreamOperatorTestHarness<>( + new ActionExecutionOperatorFactory<>(agentPlan, true, actionStateStore), + (KeySelector<Long, Long>) value -> value, + TypeInformation.of(Long.class))) { + testHarness.open(); + ActionExecutionOperator<Long, Object> operator = + (ActionExecutionOperator<Long, Object>) testHarness.getOperator(); + + testHarness.processElement(new StreamRecord<>(2L)); + operator.waitInFlightEventsFinished(); + + // CRITICAL: Verify supplier was NOT called during recovery + assertThat(TestAgent.EXCEPTION_CALL_COUNTER.get()) + .as("Supplier should NOT be called during exception recovery") + .isEqualTo(1); + } + } + public static class TestAgent { + /** Counter to track how many times the durable supplier is executed. */ + public static final java.util.concurrent.atomic.AtomicInteger DURABLE_CALL_COUNTER = + new java.util.concurrent.atomic.AtomicInteger(0); + public static class MiddleEvent extends Event { public Long num; @@ -539,6 +803,157 @@ public class ActionExecutionOperatorTest { } } + public static void asyncAction1(InputEvent event, RunnerContext context) { + Long inputData = (Long) event.getInput(); + try { + Long result = + context.durableExecuteAsync( + new DurableCallable<Long>() { + @Override + public String getId() { + return "async-multiply"; + } + + @Override + public Class<Long> getResultClass() { + return Long.class; + } + + @Override + public Long call() { + try { + Thread.sleep(50); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + return inputData * 10; + } + }); + + MemoryObject mem = context.getShortTermMemory(); + mem.set("tmp", result); + context.sendEvent(new MiddleEvent(result)); + } catch (Exception e) { + ExceptionUtils.rethrow(e); + } + } + + public static void multiAsyncAction(InputEvent event, RunnerContext context) { + Long inputData = (Long) event.getInput(); + try { + Long result1 = + context.durableExecuteAsync( + new DurableCallable<Long>() { + @Override + public String getId() { + return "async-add"; + } + + @Override + public Class<Long> getResultClass() { + return Long.class; + } + + @Override + public Long call() { + try { + Thread.sleep(30); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + return inputData + 100; + } + }); + + Long result2 = + context.durableExecuteAsync( + new DurableCallable<Long>() { + @Override + public String getId() { + return "async-multiply"; + } + + @Override + public Class<Long> getResultClass() { + return Long.class; + } + + @Override + public Long call() { + try { + Thread.sleep(30); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + return result1 * 2; + } + }); + + MemoryObject mem = context.getShortTermMemory(); + mem.set("multiAsyncResult", result2); + context.sendEvent(new OutputEvent(result2)); + } catch (Exception e) { + ExceptionUtils.rethrow(e); + } + } + + public static void durableSyncAction(InputEvent event, RunnerContext context) { + Long inputData = (Long) event.getInput(); + try { + Long result = + context.durableExecute( + new DurableCallable<Long>() { + @Override + public String getId() { + return "sync-compute"; + } + + @Override + public Class<Long> getResultClass() { + return Long.class; + } + + @Override + public Long call() { + DURABLE_CALL_COUNTER.incrementAndGet(); + return inputData * 3; + } + }); + + context.sendEvent(new OutputEvent(result)); + } catch (Exception e) { + ExceptionUtils.rethrow(e); + } + } + + public static final java.util.concurrent.atomic.AtomicInteger EXCEPTION_CALL_COUNTER = + new java.util.concurrent.atomic.AtomicInteger(0); + + public static void durableExceptionAction(InputEvent event, RunnerContext context) { + try { + context.durableExecute( + new DurableCallable<String>() { + @Override + public String getId() { + return "exception-action"; + } + + @Override + public Class<String> getResultClass() { + return String.class; + } + + @Override + public String call() { + EXCEPTION_CALL_COUNTER.incrementAndGet(); + throw new RuntimeException("Test exception from durableExecute"); + } + }); + } catch (Exception e) { + context.sendEvent(new OutputEvent("ERROR:" + e.getMessage())); + } + } + public static AgentPlan getAgentPlan(boolean testMemoryAccessOutOfMailbox) { try { Map<String, List<Action>> actionsByEvent = new HashMap<>(); @@ -586,6 +1001,114 @@ public class ActionExecutionOperatorTest { } return null; } + + /** + * Creates an AgentPlan for testing async execution. + * + * @param useMultiAsync if true, uses multiAsyncAction which chains multiple async calls + * @return AgentPlan configured with async actions + */ + public static AgentPlan getAsyncAgentPlan(boolean useMultiAsync) { + try { + Map<String, List<Action>> actionsByEvent = new HashMap<>(); + Map<String, Action> actions = new HashMap<>(); + + if (useMultiAsync) { + // Use multiAsyncAction that chains multiple executeAsync calls + Action multiAsyncAction = + new Action( + "multiAsyncAction", + new JavaFunction( + TestAgent.class, + "multiAsyncAction", + new Class<?>[] {InputEvent.class, RunnerContext.class}), + Collections.singletonList(InputEvent.class.getName())); + actionsByEvent.put( + InputEvent.class.getName(), + Collections.singletonList(multiAsyncAction)); + actions.put(multiAsyncAction.getName(), multiAsyncAction); + } else { + // Use asyncAction1 -> action2 chain + Action asyncAction1 = + new Action( + "asyncAction1", + new JavaFunction( + TestAgent.class, + "asyncAction1", + new Class<?>[] {InputEvent.class, RunnerContext.class}), + Collections.singletonList(InputEvent.class.getName())); + Action action2 = + new Action( + "action2", + new JavaFunction( + TestAgent.class, + "action2", + new Class<?>[] { + MiddleEvent.class, RunnerContext.class + }), + Collections.singletonList(MiddleEvent.class.getName())); + actionsByEvent.put( + InputEvent.class.getName(), Collections.singletonList(asyncAction1)); + actionsByEvent.put( + MiddleEvent.class.getName(), Collections.singletonList(action2)); + actions.put(asyncAction1.getName(), asyncAction1); + actions.put(action2.getName(), action2); + } + + return new AgentPlan(actions, actionsByEvent, new HashMap<>()); + } catch (Exception e) { + ExceptionUtils.rethrow(e); + } + return null; + } + + public static AgentPlan getDurableSyncAgentPlan() { + try { + Map<String, List<Action>> actionsByEvent = new HashMap<>(); + Map<String, Action> actions = new HashMap<>(); + + Action durableSyncAction = + new Action( + "durableSyncAction", + new JavaFunction( + TestAgent.class, + "durableSyncAction", + new Class<?>[] {InputEvent.class, RunnerContext.class}), + Collections.singletonList(InputEvent.class.getName())); + actionsByEvent.put( + InputEvent.class.getName(), Collections.singletonList(durableSyncAction)); + actions.put(durableSyncAction.getName(), durableSyncAction); + + return new AgentPlan(actions, actionsByEvent, new HashMap<>()); + } catch (Exception e) { + ExceptionUtils.rethrow(e); + } + return null; + } + + public static AgentPlan getDurableExceptionAgentPlan() { + try { + Map<String, List<Action>> actionsByEvent = new HashMap<>(); + Map<String, Action> actions = new HashMap<>(); + + Action exceptionAction = + new Action( + "durableExceptionAction", + new JavaFunction( + TestAgent.class, + "durableExceptionAction", + new Class<?>[] {InputEvent.class, RunnerContext.class}), + Collections.singletonList(InputEvent.class.getName())); + actionsByEvent.put( + InputEvent.class.getName(), Collections.singletonList(exceptionAction)); + actions.put(exceptionAction.getName(), exceptionAction); + + return new AgentPlan(actions, actionsByEvent, new HashMap<>()); + } catch (Exception e) { + ExceptionUtils.rethrow(e); + } + return null; + } } private static void assertMailboxSizeAndRun(TaskMailbox mailbox, int expectedSize)
