joeyutong commented on code in PR #628: URL: https://github.com/apache/flink-agents/pull/628#discussion_r3159701450
########## runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/FlussActionStateStore.java: ########## @@ -0,0 +1,489 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.actionstate; + +import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.plan.AgentConfiguration; +import org.apache.flink.agents.plan.actions.Action; +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.fluss.client.Connection; +import org.apache.fluss.client.ConnectionFactory; +import org.apache.fluss.client.admin.Admin; +import org.apache.fluss.client.admin.OffsetSpec; +import org.apache.fluss.client.table.Table; +import org.apache.fluss.client.table.scanner.ScanRecord; +import org.apache.fluss.client.table.scanner.log.LogScanner; +import org.apache.fluss.client.table.scanner.log.ScanRecords; +import org.apache.fluss.client.table.writer.AppendWriter; +import org.apache.fluss.config.ConfigOptions; +import org.apache.fluss.config.Configuration; +import org.apache.fluss.metadata.DatabaseDescriptor; +import org.apache.fluss.metadata.Schema; +import org.apache.fluss.metadata.TableBucket; +import org.apache.fluss.metadata.TableDescriptor; +import 
org.apache.fluss.metadata.TablePath; +import org.apache.fluss.row.BinaryString; +import org.apache.fluss.row.GenericRow; +import org.apache.fluss.row.InternalRow; +import org.apache.fluss.types.DataTypes; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_ACTION_STATE_DATABASE; +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_ACTION_STATE_TABLE; +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_ACTION_STATE_TABLE_BUCKETS; +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_BOOTSTRAP_SERVERS; +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_SASL_JAAS_CONFIG; +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_SASL_MECHANISM; +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_SASL_PASSWORD; +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_SASL_USERNAME; +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_SECURITY_PROTOCOL; +import static org.apache.flink.agents.runtime.actionstate.ActionStateUtil.generateKey; +import static org.apache.fluss.config.ConfigOptions.BOOTSTRAP_SERVERS; +import static org.apache.fluss.config.ConfigOptions.CLIENT_SASL_JAAS_CONFIG; +import static org.apache.fluss.config.ConfigOptions.CLIENT_SASL_JAAS_PASSWORD; +import static org.apache.fluss.config.ConfigOptions.CLIENT_SASL_JAAS_USERNAME; +import static org.apache.fluss.config.ConfigOptions.CLIENT_SASL_MECHANISM; +import static org.apache.fluss.config.ConfigOptions.CLIENT_SECURITY_PROTOCOL; + +/** + * An implementation of {@link ActionStateStore} that uses an Apache Fluss log table as the backend. 
+ * All state is maintained in an in-memory map for fast lookups, with the Fluss log table providing + * durability and recovery support. + */ +public class FlussActionStateStore implements ActionStateStore { + + private static final Logger LOG = LoggerFactory.getLogger(FlussActionStateStore.class); + + private static final Duration POLL_TIMEOUT = Duration.ofSeconds(1); + + // Column names in the Fluss table schema + private static final String COL_NAME_STATE_KEY = "state_key"; + private static final String COL_NAME_STATE_PAYLOAD = "state_payload"; + private static final String COL_NAME_AGENT_KEY = "agent_key"; + + // Column indices in the Fluss table schema + private static final int COL_STATE_KEY = 0; + private static final int COL_STATE_PAYLOAD = 1; + + private final AgentConfiguration agentConfiguration; + private final String databaseName; + private final String tableName; + private final TablePath tablePath; + + private final Connection connection; + private final Table table; + private final AppendWriter writer; + + /** In-memory cache for O(1) state lookups; rebuilt from Fluss log on recovery. 
*/ + private final Map<String, ActionState> actionStates; + + @VisibleForTesting + FlussActionStateStore( + Map<String, ActionState> actionStates, + Connection connection, + Table table, + AppendWriter writer) { + this.agentConfiguration = null; + this.databaseName = null; + this.tableName = null; + this.tablePath = null; + this.actionStates = actionStates; + this.connection = connection; + this.table = table; + this.writer = writer; + } + + public FlussActionStateStore(AgentConfiguration agentConfiguration) { + this.agentConfiguration = agentConfiguration; + this.databaseName = agentConfiguration.get(FLUSS_ACTION_STATE_DATABASE); + this.tableName = agentConfiguration.get(FLUSS_ACTION_STATE_TABLE); + this.tablePath = TablePath.of(databaseName, tableName); + this.actionStates = new HashMap<>(); + + Configuration flussConf = new Configuration(); + flussConf.setString( + BOOTSTRAP_SERVERS.key(), agentConfiguration.get(FLUSS_BOOTSTRAP_SERVERS)); + // Minimize latency for synchronous put(): setting batch linger time to zero ensures + // that each append is sent immediately without waiting for additional records to batch. 
+ flussConf.set(ConfigOptions.CLIENT_WRITER_BATCH_TIMEOUT, Duration.ZERO); + + flussConf.setString( + CLIENT_SECURITY_PROTOCOL, agentConfiguration.get(FLUSS_SECURITY_PROTOCOL)); + + flussConf.setString(CLIENT_SASL_MECHANISM, agentConfiguration.get(FLUSS_SASL_MECHANISM)); + + String jaasConfig = agentConfiguration.get(FLUSS_SASL_JAAS_CONFIG); + if (jaasConfig != null) { + flussConf.setString(CLIENT_SASL_JAAS_CONFIG, jaasConfig); + } + String username = agentConfiguration.get(FLUSS_SASL_USERNAME); + if (username != null) { + flussConf.setString(CLIENT_SASL_JAAS_USERNAME, username); + } + String password = agentConfiguration.get(FLUSS_SASL_PASSWORD); + if (password != null) { + flussConf.setString(CLIENT_SASL_JAAS_PASSWORD, password); + } + + this.connection = ConnectionFactory.createConnection(flussConf); + maybeCreateDatabaseAndTable(); + this.table = connection.getTable(tablePath); + this.writer = table.newAppend().createWriter(); + + LOG.info( + "Initialized FlussActionStateStore (log table) with table: {}.{}", + databaseName, + tableName); + } + + @Override + public void put(Object key, long seqNum, Action action, Event event, ActionState state) + throws Exception { + String stateKey = generateKey(key, seqNum, action, event); + byte[] payload = ActionStateSerde.serialize(state); + + GenericRow row = + GenericRow.of( + BinaryString.fromString(stateKey), + payload, + BinaryString.fromString(key.toString())); + + // Append to Fluss log for durability, then update in-memory cache. + // Synchronous write ensures the record is durable before returning. 
+ writer.append(row).get(); + actionStates.put(stateKey, state); + + LOG.debug("Stored action state: key={}, isCompleted={}", stateKey, state.isCompleted()); + } + + @Override + public ActionState get(Object key, long seqNum, Action action, Event event) throws Exception { + String stateKey = generateKey(key, seqNum, action, event); + + boolean hasDivergence = checkDivergence(key.toString(), seqNum); + + if (!actionStates.containsKey(stateKey) || hasDivergence) { + actionStates + .entrySet() + .removeIf( + entry -> { + try { + List<String> parts = ActionStateUtil.parseKey(entry.getKey()); + if (parts.size() >= 2) { + long stateSeqNum = Long.parseLong(parts.get(1)); + return stateSeqNum > seqNum; + } + } catch (NumberFormatException e) { + LOG.warn( + "Failed to parse sequence number from state key: {}", + stateKey); + } + return false; + }); + } + + ActionState state = actionStates.get(stateKey); + LOG.debug("Lookup action state: key={}, found={}", stateKey, state != null); + return state; + } + + private boolean checkDivergence(String key, long seqNum) { + return actionStates.keySet().stream() + .filter(k -> k.startsWith(key + "_" + seqNum + "_")) + .count() + > 1; + } + + /** + * Rebuilds in-memory state by scanning the Fluss log table. If recovery markers are provided, + * computes the minimum offset per bucket across all markers and subscribes from those offsets. + * Otherwise, skips rebuild since there is no checkpointed position to recover from. Reads from + * the start offset up to the latest offset captured at rebuild start. For the same state key + * appearing multiple times in the log, the latest record wins (last-write-wins). + */ + @Override + public void rebuildState(List<Object> recoveryMarkers) { Review Comment: Could we split `rebuildState()` into a few smaller helpers? It currently handles marker merging, effective offset calculation, scanner subscription, bounded polling, and row deserialization in one method. 
Extracting pieces such as `mergeRecoveryMarkerOffsets`, `subscribeEffectiveOffsets`, and the scan/replay loop would make the recovery boundary easier to review and reduce the risk of subtle offset bugs. ########## runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/FlussActionStateStore.java: ########## @@ -0,0 +1,489 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.agents.runtime.actionstate; + +import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.plan.AgentConfiguration; +import org.apache.flink.agents.plan.actions.Action; +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.fluss.client.Connection; +import org.apache.fluss.client.ConnectionFactory; +import org.apache.fluss.client.admin.Admin; +import org.apache.fluss.client.admin.OffsetSpec; +import org.apache.fluss.client.table.Table; +import org.apache.fluss.client.table.scanner.ScanRecord; +import org.apache.fluss.client.table.scanner.log.LogScanner; +import org.apache.fluss.client.table.scanner.log.ScanRecords; +import org.apache.fluss.client.table.writer.AppendWriter; +import org.apache.fluss.config.ConfigOptions; +import org.apache.fluss.config.Configuration; +import org.apache.fluss.metadata.DatabaseDescriptor; +import org.apache.fluss.metadata.Schema; +import org.apache.fluss.metadata.TableBucket; +import org.apache.fluss.metadata.TableDescriptor; +import org.apache.fluss.metadata.TablePath; +import org.apache.fluss.row.BinaryString; +import org.apache.fluss.row.GenericRow; +import org.apache.fluss.row.InternalRow; +import org.apache.fluss.types.DataTypes; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_ACTION_STATE_DATABASE; +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_ACTION_STATE_TABLE; +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_ACTION_STATE_TABLE_BUCKETS; +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_BOOTSTRAP_SERVERS; +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_SASL_JAAS_CONFIG; +import static 
org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_SASL_MECHANISM; +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_SASL_PASSWORD; +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_SASL_USERNAME; +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_SECURITY_PROTOCOL; +import static org.apache.flink.agents.runtime.actionstate.ActionStateUtil.generateKey; +import static org.apache.fluss.config.ConfigOptions.BOOTSTRAP_SERVERS; +import static org.apache.fluss.config.ConfigOptions.CLIENT_SASL_JAAS_CONFIG; +import static org.apache.fluss.config.ConfigOptions.CLIENT_SASL_JAAS_PASSWORD; +import static org.apache.fluss.config.ConfigOptions.CLIENT_SASL_JAAS_USERNAME; +import static org.apache.fluss.config.ConfigOptions.CLIENT_SASL_MECHANISM; +import static org.apache.fluss.config.ConfigOptions.CLIENT_SECURITY_PROTOCOL; + +/** + * An implementation of {@link ActionStateStore} that uses an Apache Fluss log table as the backend. + * All state is maintained in an in-memory map for fast lookups, with the Fluss log table providing + * durability and recovery support. 
+ */ +public class FlussActionStateStore implements ActionStateStore { + + private static final Logger LOG = LoggerFactory.getLogger(FlussActionStateStore.class); + + private static final Duration POLL_TIMEOUT = Duration.ofSeconds(1); + + // Column names in the Fluss table schema + private static final String COL_NAME_STATE_KEY = "state_key"; + private static final String COL_NAME_STATE_PAYLOAD = "state_payload"; + private static final String COL_NAME_AGENT_KEY = "agent_key"; + + // Column indices in the Fluss table schema + private static final int COL_STATE_KEY = 0; + private static final int COL_STATE_PAYLOAD = 1; + + private final AgentConfiguration agentConfiguration; + private final String databaseName; + private final String tableName; + private final TablePath tablePath; + + private final Connection connection; + private final Table table; + private final AppendWriter writer; + + /** In-memory cache for O(1) state lookups; rebuilt from Fluss log on recovery. */ + private final Map<String, ActionState> actionStates; + + @VisibleForTesting + FlussActionStateStore( + Map<String, ActionState> actionStates, + Connection connection, + Table table, + AppendWriter writer) { + this.agentConfiguration = null; + this.databaseName = null; + this.tableName = null; + this.tablePath = null; + this.actionStates = actionStates; + this.connection = connection; + this.table = table; + this.writer = writer; + } + + public FlussActionStateStore(AgentConfiguration agentConfiguration) { + this.agentConfiguration = agentConfiguration; + this.databaseName = agentConfiguration.get(FLUSS_ACTION_STATE_DATABASE); + this.tableName = agentConfiguration.get(FLUSS_ACTION_STATE_TABLE); + this.tablePath = TablePath.of(databaseName, tableName); + this.actionStates = new HashMap<>(); + + Configuration flussConf = new Configuration(); + flussConf.setString( + BOOTSTRAP_SERVERS.key(), agentConfiguration.get(FLUSS_BOOTSTRAP_SERVERS)); + // Minimize latency for synchronous put(): setting batch 
linger time to zero ensures + // that each append is sent immediately without waiting for additional records to batch. + flussConf.set(ConfigOptions.CLIENT_WRITER_BATCH_TIMEOUT, Duration.ZERO); + + flussConf.setString( + CLIENT_SECURITY_PROTOCOL, agentConfiguration.get(FLUSS_SECURITY_PROTOCOL)); + + flussConf.setString(CLIENT_SASL_MECHANISM, agentConfiguration.get(FLUSS_SASL_MECHANISM)); + + String jaasConfig = agentConfiguration.get(FLUSS_SASL_JAAS_CONFIG); + if (jaasConfig != null) { + flussConf.setString(CLIENT_SASL_JAAS_CONFIG, jaasConfig); + } + String username = agentConfiguration.get(FLUSS_SASL_USERNAME); + if (username != null) { + flussConf.setString(CLIENT_SASL_JAAS_USERNAME, username); + } + String password = agentConfiguration.get(FLUSS_SASL_PASSWORD); + if (password != null) { + flussConf.setString(CLIENT_SASL_JAAS_PASSWORD, password); + } + + this.connection = ConnectionFactory.createConnection(flussConf); + maybeCreateDatabaseAndTable(); + this.table = connection.getTable(tablePath); + this.writer = table.newAppend().createWriter(); + + LOG.info( + "Initialized FlussActionStateStore (log table) with table: {}.{}", + databaseName, + tableName); + } + + @Override + public void put(Object key, long seqNum, Action action, Event event, ActionState state) + throws Exception { + String stateKey = generateKey(key, seqNum, action, event); + byte[] payload = ActionStateSerde.serialize(state); + + GenericRow row = + GenericRow.of( + BinaryString.fromString(stateKey), + payload, + BinaryString.fromString(key.toString())); + + // Append to Fluss log for durability, then update in-memory cache. + // Synchronous write ensures the record is durable before returning. 
+ writer.append(row).get(); + actionStates.put(stateKey, state); + + LOG.debug("Stored action state: key={}, isCompleted={}", stateKey, state.isCompleted()); + } + + @Override + public ActionState get(Object key, long seqNum, Action action, Event event) throws Exception { + String stateKey = generateKey(key, seqNum, action, event); + + boolean hasDivergence = checkDivergence(key.toString(), seqNum); + + if (!actionStates.containsKey(stateKey) || hasDivergence) { + actionStates + .entrySet() + .removeIf( + entry -> { + try { + List<String> parts = ActionStateUtil.parseKey(entry.getKey()); + if (parts.size() >= 2) { + long stateSeqNum = Long.parseLong(parts.get(1)); + return stateSeqNum > seqNum; + } + } catch (NumberFormatException e) { + LOG.warn( + "Failed to parse sequence number from state key: {}", + stateKey); + } + return false; + }); + } + + ActionState state = actionStates.get(stateKey); + LOG.debug("Lookup action state: key={}, found={}", stateKey, state != null); + return state; + } + + private boolean checkDivergence(String key, long seqNum) { + return actionStates.keySet().stream() + .filter(k -> k.startsWith(key + "_" + seqNum + "_")) + .count() + > 1; + } + + /** + * Rebuilds in-memory state by scanning the Fluss log table. If recovery markers are provided, + * computes the minimum offset per bucket across all markers and subscribes from those offsets. + * Otherwise, skips rebuild since there is no checkpointed position to recover from. Reads from + * the start offset up to the latest offset captured at rebuild start. For the same state key + * appearing multiple times in the log, the latest record wins (last-write-wins). 
+ */ + @Override + public void rebuildState(List<Object> recoveryMarkers) { + LOG.info( + "Rebuilding action state from Fluss log table with {} recovery markers", + recoveryMarkers.size()); + + if (recoveryMarkers.isEmpty()) { + LOG.info("No recovery markers, skipping state rebuild"); + return; + } + + actionStates.clear(); + + // Compute per-bucket start offsets from recovery markers + Map<Integer, Long> bucketStartOffsets = new HashMap<>(); + for (Object marker : recoveryMarkers) { + if (marker instanceof Map) { + @SuppressWarnings("unchecked") + Map<Integer, Long> markerMap = (Map<Integer, Long>) marker; + for (Map.Entry<Integer, Long> entry : markerMap.entrySet()) { + bucketStartOffsets.merge(entry.getKey(), entry.getValue(), Math::min); + } + } else if (marker != null) { + LOG.warn( + "Ignoring unrecognized recovery marker type: {}", + marker.getClass().getName()); + } + } + + if (bucketStartOffsets.isEmpty()) { + LOG.info("No valid bucket offsets in recovery markers, skipping state rebuild"); + return; + } + + // Capture the latest offsets as the stopping point for each bucket + Map<Integer, Long> bucketEndOffsets = getBucketEndOffsets(); + // Capture the earliest available offsets so we can skip buckets whose data has + // been fully cleaned by log retention (earliestOffset >= endOffset). 
+ Map<Integer, Long> bucketEarliestOffsets = getBucketEarliestOffsets(); + LOG.debug( + "Rebuild window: startOffsets={}, earliestOffsets={}, endOffsets={}", + bucketStartOffsets, + bucketEarliestOffsets, + bucketEndOffsets); + + try (LogScanner scanner = table.newScan().createLogScanner()) { + // Track which buckets still need to be consumed + Map<Integer, Long> remainingBuckets = new HashMap<>(); + for (Map.Entry<Integer, Long> entry : bucketStartOffsets.entrySet()) { + int bucket = entry.getKey(); + long startOffset = entry.getValue(); + Long endOffset = bucketEndOffsets.get(bucket); + + // Bucket referenced in recovery marker does not exist in current table + if (endOffset == null) { + LOG.warn( + "Bucket {} referenced in recovery marker does not exist " + + "in current table, state recovery for this bucket is skipped", + bucket); + continue; + } + + // No new data since checkpoint (includes empty buckets that never had writes) + if (endOffset <= startOffset) { + LOG.info( + "Skipping bucket {} for rebuild: no new data " + + "(endOffset={} <= startOffset={})", + bucket, + endOffset, + startOffset); + continue; + } + + // Check if retention has cleaned data in the recovery window + Long earliestOffset = bucketEarliestOffsets.get(bucket); + long effectiveStart = startOffset; + if (earliestOffset != null && earliestOffset > startOffset) { + effectiveStart = earliestOffset; + if (effectiveStart >= endOffset) { + // All data in recovery window has been cleaned by retention + LOG.warn( + "Bucket {} state recovery failed: all data between offset {} " + + "and {} has been cleaned by log retention " + + "(earliest available: {})", + bucket, + startOffset, + endOffset, + earliestOffset); + continue; + } + // Partial data loss: some state between startOffset and earliestOffset is gone + LOG.warn( + "Bucket {} partial state loss: data between offset {} and {} " + + "has been cleaned by log retention, " + + "recovering from offset {} instead", + bucket, + startOffset, + 
earliestOffset, + effectiveStart); + } + + scanner.subscribe(bucket, effectiveStart); + remainingBuckets.put(bucket, endOffset); + } + LOG.debug("Subscribed buckets for rebuild: {}", remainingBuckets); + + while (!remainingBuckets.isEmpty()) { + + ScanRecords records = scanner.poll(POLL_TIMEOUT); + for (TableBucket bucket : records.buckets()) { + Long endOffset = remainingBuckets.get(bucket.getBucket()); + if (endOffset == null) { + continue; + } + // Track the highest offset seen in this batch (including skipped records + // beyond endOffset) so that we can detect when the bucket has been fully + // consumed even if the last records are past our target window. + long lastSeenOffset = -1; + for (ScanRecord record : records.records(bucket)) { + lastSeenOffset = record.logOffset(); + if (record.logOffset() >= endOffset) { Review Comment: Once we see a record at or beyond `endOffset` for a bucket, shouldn't we stop scanning that bucket instead of continuing through the rest of the records? Offsets within a bucket should be ordered, and records at or beyond `endOffset` are outside the recovery window. Using `break` here would make the bounded-replay semantics clearer and avoid iterating over records that should never be restored. ########## runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/FlussActionStateStore.java: ########## @@ -0,0 +1,489 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.actionstate; + +import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.plan.AgentConfiguration; +import org.apache.flink.agents.plan.actions.Action; +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.fluss.client.Connection; +import org.apache.fluss.client.ConnectionFactory; +import org.apache.fluss.client.admin.Admin; +import org.apache.fluss.client.admin.OffsetSpec; +import org.apache.fluss.client.table.Table; +import org.apache.fluss.client.table.scanner.ScanRecord; +import org.apache.fluss.client.table.scanner.log.LogScanner; +import org.apache.fluss.client.table.scanner.log.ScanRecords; +import org.apache.fluss.client.table.writer.AppendWriter; +import org.apache.fluss.config.ConfigOptions; +import org.apache.fluss.config.Configuration; +import org.apache.fluss.metadata.DatabaseDescriptor; +import org.apache.fluss.metadata.Schema; +import org.apache.fluss.metadata.TableBucket; +import org.apache.fluss.metadata.TableDescriptor; +import org.apache.fluss.metadata.TablePath; +import org.apache.fluss.row.BinaryString; +import org.apache.fluss.row.GenericRow; +import org.apache.fluss.row.InternalRow; +import org.apache.fluss.types.DataTypes; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_ACTION_STATE_DATABASE; +import static 
org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_ACTION_STATE_TABLE; +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_ACTION_STATE_TABLE_BUCKETS; +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_BOOTSTRAP_SERVERS; +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_SASL_JAAS_CONFIG; +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_SASL_MECHANISM; +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_SASL_PASSWORD; +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_SASL_USERNAME; +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_SECURITY_PROTOCOL; +import static org.apache.flink.agents.runtime.actionstate.ActionStateUtil.generateKey; +import static org.apache.fluss.config.ConfigOptions.BOOTSTRAP_SERVERS; +import static org.apache.fluss.config.ConfigOptions.CLIENT_SASL_JAAS_CONFIG; +import static org.apache.fluss.config.ConfigOptions.CLIENT_SASL_JAAS_PASSWORD; +import static org.apache.fluss.config.ConfigOptions.CLIENT_SASL_JAAS_USERNAME; +import static org.apache.fluss.config.ConfigOptions.CLIENT_SASL_MECHANISM; +import static org.apache.fluss.config.ConfigOptions.CLIENT_SECURITY_PROTOCOL; + +/** + * An implementation of {@link ActionStateStore} that uses an Apache Fluss log table as the backend. + * All state is maintained in an in-memory map for fast lookups, with the Fluss log table providing + * durability and recovery support. 
+ */ +public class FlussActionStateStore implements ActionStateStore { + + private static final Logger LOG = LoggerFactory.getLogger(FlussActionStateStore.class); + + private static final Duration POLL_TIMEOUT = Duration.ofSeconds(1); + + // Column names in the Fluss table schema + private static final String COL_NAME_STATE_KEY = "state_key"; + private static final String COL_NAME_STATE_PAYLOAD = "state_payload"; + private static final String COL_NAME_AGENT_KEY = "agent_key"; + + // Column indices in the Fluss table schema + private static final int COL_STATE_KEY = 0; + private static final int COL_STATE_PAYLOAD = 1; + + private final AgentConfiguration agentConfiguration; + private final String databaseName; + private final String tableName; + private final TablePath tablePath; + + private final Connection connection; + private final Table table; + private final AppendWriter writer; + + /** In-memory cache for O(1) state lookups; rebuilt from Fluss log on recovery. */ + private final Map<String, ActionState> actionStates; + + @VisibleForTesting + FlussActionStateStore( + Map<String, ActionState> actionStates, + Connection connection, + Table table, + AppendWriter writer) { + this.agentConfiguration = null; + this.databaseName = null; + this.tableName = null; + this.tablePath = null; + this.actionStates = actionStates; + this.connection = connection; + this.table = table; + this.writer = writer; + } + + public FlussActionStateStore(AgentConfiguration agentConfiguration) { + this.agentConfiguration = agentConfiguration; + this.databaseName = agentConfiguration.get(FLUSS_ACTION_STATE_DATABASE); + this.tableName = agentConfiguration.get(FLUSS_ACTION_STATE_TABLE); + this.tablePath = TablePath.of(databaseName, tableName); + this.actionStates = new HashMap<>(); + + Configuration flussConf = new Configuration(); + flussConf.setString( + BOOTSTRAP_SERVERS.key(), agentConfiguration.get(FLUSS_BOOTSTRAP_SERVERS)); + // Minimize latency for synchronous put(): setting batch 
linger time to zero ensures + // that each append is sent immediately without waiting for additional records to batch. + flussConf.set(ConfigOptions.CLIENT_WRITER_BATCH_TIMEOUT, Duration.ZERO); + + flussConf.setString( + CLIENT_SECURITY_PROTOCOL, agentConfiguration.get(FLUSS_SECURITY_PROTOCOL)); + + flussConf.setString(CLIENT_SASL_MECHANISM, agentConfiguration.get(FLUSS_SASL_MECHANISM)); + + String jaasConfig = agentConfiguration.get(FLUSS_SASL_JAAS_CONFIG); + if (jaasConfig != null) { + flussConf.setString(CLIENT_SASL_JAAS_CONFIG, jaasConfig); + } + String username = agentConfiguration.get(FLUSS_SASL_USERNAME); + if (username != null) { + flussConf.setString(CLIENT_SASL_JAAS_USERNAME, username); + } + String password = agentConfiguration.get(FLUSS_SASL_PASSWORD); + if (password != null) { + flussConf.setString(CLIENT_SASL_JAAS_PASSWORD, password); + } + + this.connection = ConnectionFactory.createConnection(flussConf); + maybeCreateDatabaseAndTable(); + this.table = connection.getTable(tablePath); + this.writer = table.newAppend().createWriter(); + + LOG.info( + "Initialized FlussActionStateStore (log table) with table: {}.{}", + databaseName, + tableName); + } + + @Override + public void put(Object key, long seqNum, Action action, Event event, ActionState state) + throws Exception { + String stateKey = generateKey(key, seqNum, action, event); + byte[] payload = ActionStateSerde.serialize(state); + + GenericRow row = + GenericRow.of( + BinaryString.fromString(stateKey), + payload, + BinaryString.fromString(key.toString())); + + // Append to Fluss log for durability, then update in-memory cache. + // Synchronous write ensures the record is durable before returning. 
+ writer.append(row).get(); + actionStates.put(stateKey, state); + + LOG.debug("Stored action state: key={}, isCompleted={}", stateKey, state.isCompleted()); + } + + @Override + public ActionState get(Object key, long seqNum, Action action, Event event) throws Exception { + String stateKey = generateKey(key, seqNum, action, event); + + boolean hasDivergence = checkDivergence(key.toString(), seqNum); + + if (!actionStates.containsKey(stateKey) || hasDivergence) { + actionStates + .entrySet() + .removeIf( + entry -> { + try { + List<String> parts = ActionStateUtil.parseKey(entry.getKey()); + if (parts.size() >= 2) { + long stateSeqNum = Long.parseLong(parts.get(1)); + return stateSeqNum > seqNum; + } + } catch (NumberFormatException e) { + LOG.warn( + "Failed to parse sequence number from state key: {}", + stateKey); + } + return false; + }); + } + + ActionState state = actionStates.get(stateKey); + LOG.debug("Lookup action state: key={}, found={}", stateKey, state != null); + return state; + } + + private boolean checkDivergence(String key, long seqNum) { + return actionStates.keySet().stream() + .filter(k -> k.startsWith(key + "_" + seqNum + "_")) + .count() + > 1; + } + + /** + * Rebuilds in-memory state by scanning the Fluss log table. If recovery markers are provided, + * computes the minimum offset per bucket across all markers and subscribes from those offsets. + * Otherwise, skips rebuild since there is no checkpointed position to recover from. Reads from + * the start offset up to the latest offset captured at rebuild start. For the same state key + * appearing multiple times in the log, the latest record wins (last-write-wins). 
+ */ + @Override + public void rebuildState(List<Object> recoveryMarkers) { + LOG.info( + "Rebuilding action state from Fluss log table with {} recovery markers", + recoveryMarkers.size()); + + if (recoveryMarkers.isEmpty()) { + LOG.info("No recovery markers, skipping state rebuild"); + return; + } + + actionStates.clear(); + + // Compute per-bucket start offsets from recovery markers + Map<Integer, Long> bucketStartOffsets = new HashMap<>(); + for (Object marker : recoveryMarkers) { + if (marker instanceof Map) { + @SuppressWarnings("unchecked") + Map<Integer, Long> markerMap = (Map<Integer, Long>) marker; + for (Map.Entry<Integer, Long> entry : markerMap.entrySet()) { + bucketStartOffsets.merge(entry.getKey(), entry.getValue(), Math::min); + } + } else if (marker != null) { + LOG.warn( + "Ignoring unrecognized recovery marker type: {}", + marker.getClass().getName()); + } + } + + if (bucketStartOffsets.isEmpty()) { + LOG.info("No valid bucket offsets in recovery markers, skipping state rebuild"); + return; + } + + // Capture the latest offsets as the stopping point for each bucket + Map<Integer, Long> bucketEndOffsets = getBucketEndOffsets(); + // Capture the earliest available offsets so we can skip buckets whose data has + // been fully cleaned by log retention (earliestOffset >= endOffset). 
+ Map<Integer, Long> bucketEarliestOffsets = getBucketEarliestOffsets(); + LOG.debug( + "Rebuild window: startOffsets={}, earliestOffsets={}, endOffsets={}", + bucketStartOffsets, + bucketEarliestOffsets, + bucketEndOffsets); + + try (LogScanner scanner = table.newScan().createLogScanner()) { + // Track which buckets still need to be consumed + Map<Integer, Long> remainingBuckets = new HashMap<>(); + for (Map.Entry<Integer, Long> entry : bucketStartOffsets.entrySet()) { + int bucket = entry.getKey(); + long startOffset = entry.getValue(); + Long endOffset = bucketEndOffsets.get(bucket); + + // Bucket referenced in recovery marker does not exist in current table + if (endOffset == null) { + LOG.warn( + "Bucket {} referenced in recovery marker does not exist " + + "in current table, state recovery for this bucket is skipped", + bucket); + continue; + } + + // No new data since checkpoint (includes empty buckets that never had writes) + if (endOffset <= startOffset) { + LOG.info( + "Skipping bucket {} for rebuild: no new data " + + "(endOffset={} <= startOffset={})", + bucket, + endOffset, + startOffset); + continue; + } + + // Check if retention has cleaned data in the recovery window + Long earliestOffset = bucketEarliestOffsets.get(bucket); + long effectiveStart = startOffset; + if (earliestOffset != null && earliestOffset > startOffset) { + effectiveStart = earliestOffset; + if (effectiveStart >= endOffset) { + // All data in recovery window has been cleaned by retention + LOG.warn( + "Bucket {} state recovery failed: all data between offset {} " + + "and {} has been cleaned by log retention " + + "(earliest available: {})", + bucket, + startOffset, + endOffset, + earliestOffset); + continue; + } + // Partial data loss: some state between startOffset and earliestOffset is gone + LOG.warn( + "Bucket {} partial state loss: data between offset {} and {} " + + "has been cleaned by log retention, " + + "recovering from offset {} instead", + bucket, + startOffset, + 
earliestOffset, + effectiveStart); + } + + scanner.subscribe(bucket, effectiveStart); + remainingBuckets.put(bucket, endOffset); + } + LOG.debug("Subscribed buckets for rebuild: {}", remainingBuckets); + + while (!remainingBuckets.isEmpty()) { + + ScanRecords records = scanner.poll(POLL_TIMEOUT); + for (TableBucket bucket : records.buckets()) { + Long endOffset = remainingBuckets.get(bucket.getBucket()); + if (endOffset == null) { + continue; + } + // Track the highest offset seen in this batch (including skipped records + // beyond endOffset) so that we can detect when the bucket has been fully + // consumed even if the last records are past our target window. + long lastSeenOffset = -1; + for (ScanRecord record : records.records(bucket)) { + lastSeenOffset = record.logOffset(); + if (record.logOffset() >= endOffset) { + continue; + } + InternalRow row = record.getRow(); + String stateKey = row.getString(COL_STATE_KEY).toString(); + byte[] payload = row.getBytes(COL_STATE_PAYLOAD); + try { + ActionState state = ActionStateSerde.deserialize(payload); + actionStates.put(stateKey, state); + } catch (Exception e) { + LOG.warn( + "Failed to deserialize action state for key: {}, skipping", + stateKey, + e); + } + } + // Remove bucket if the highest seen offset has reached or passed the end + if (lastSeenOffset + 1 >= endOffset) { Review Comment: When a bucket reaches its `endOffset`, should we also unsubscribe it from the scanner? Removing it from `remainingBuckets` prevents further replay, but the scanner is still subscribed to that bucket and may keep fetching records beyond the recovery window while other buckets are still being rebuilt. Calling `scanner.unsubscribe(bucket.getBucket())` here would make the bounded-replay boundary explicit and avoid unnecessary fetches for completed buckets. 
########## runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/FlussActionStateStore.java: ########## @@ -0,0 +1,489 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.actionstate; + +import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.plan.AgentConfiguration; +import org.apache.flink.agents.plan.actions.Action; +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.fluss.client.Connection; +import org.apache.fluss.client.ConnectionFactory; +import org.apache.fluss.client.admin.Admin; +import org.apache.fluss.client.admin.OffsetSpec; +import org.apache.fluss.client.table.Table; +import org.apache.fluss.client.table.scanner.ScanRecord; +import org.apache.fluss.client.table.scanner.log.LogScanner; +import org.apache.fluss.client.table.scanner.log.ScanRecords; +import org.apache.fluss.client.table.writer.AppendWriter; +import org.apache.fluss.config.ConfigOptions; +import org.apache.fluss.config.Configuration; +import org.apache.fluss.metadata.DatabaseDescriptor; +import org.apache.fluss.metadata.Schema; +import org.apache.fluss.metadata.TableBucket; +import org.apache.fluss.metadata.TableDescriptor; +import 
org.apache.fluss.metadata.TablePath; +import org.apache.fluss.row.BinaryString; +import org.apache.fluss.row.GenericRow; +import org.apache.fluss.row.InternalRow; +import org.apache.fluss.types.DataTypes; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_ACTION_STATE_DATABASE; +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_ACTION_STATE_TABLE; +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_ACTION_STATE_TABLE_BUCKETS; +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_BOOTSTRAP_SERVERS; +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_SASL_JAAS_CONFIG; +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_SASL_MECHANISM; +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_SASL_PASSWORD; +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_SASL_USERNAME; +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_SECURITY_PROTOCOL; +import static org.apache.flink.agents.runtime.actionstate.ActionStateUtil.generateKey; +import static org.apache.fluss.config.ConfigOptions.BOOTSTRAP_SERVERS; +import static org.apache.fluss.config.ConfigOptions.CLIENT_SASL_JAAS_CONFIG; +import static org.apache.fluss.config.ConfigOptions.CLIENT_SASL_JAAS_PASSWORD; +import static org.apache.fluss.config.ConfigOptions.CLIENT_SASL_JAAS_USERNAME; +import static org.apache.fluss.config.ConfigOptions.CLIENT_SASL_MECHANISM; +import static org.apache.fluss.config.ConfigOptions.CLIENT_SECURITY_PROTOCOL; + +/** + * An implementation of {@link ActionStateStore} that uses an Apache Fluss log table as the backend. 
+ * All state is maintained in an in-memory map for fast lookups, with the Fluss log table providing + * durability and recovery support. + */ +public class FlussActionStateStore implements ActionStateStore { + + private static final Logger LOG = LoggerFactory.getLogger(FlussActionStateStore.class); + + private static final Duration POLL_TIMEOUT = Duration.ofSeconds(1); + + // Column names in the Fluss table schema + private static final String COL_NAME_STATE_KEY = "state_key"; + private static final String COL_NAME_STATE_PAYLOAD = "state_payload"; + private static final String COL_NAME_AGENT_KEY = "agent_key"; + + // Column indices in the Fluss table schema + private static final int COL_STATE_KEY = 0; + private static final int COL_STATE_PAYLOAD = 1; + + private final AgentConfiguration agentConfiguration; + private final String databaseName; + private final String tableName; + private final TablePath tablePath; + + private final Connection connection; + private final Table table; + private final AppendWriter writer; + + /** In-memory cache for O(1) state lookups; rebuilt from Fluss log on recovery. 
*/ + private final Map<String, ActionState> actionStates; + + @VisibleForTesting + FlussActionStateStore( + Map<String, ActionState> actionStates, + Connection connection, + Table table, + AppendWriter writer) { + this.agentConfiguration = null; + this.databaseName = null; + this.tableName = null; + this.tablePath = null; + this.actionStates = actionStates; + this.connection = connection; + this.table = table; + this.writer = writer; + } + + public FlussActionStateStore(AgentConfiguration agentConfiguration) { + this.agentConfiguration = agentConfiguration; + this.databaseName = agentConfiguration.get(FLUSS_ACTION_STATE_DATABASE); + this.tableName = agentConfiguration.get(FLUSS_ACTION_STATE_TABLE); + this.tablePath = TablePath.of(databaseName, tableName); + this.actionStates = new HashMap<>(); + + Configuration flussConf = new Configuration(); + flussConf.setString( + BOOTSTRAP_SERVERS.key(), agentConfiguration.get(FLUSS_BOOTSTRAP_SERVERS)); + // Minimize latency for synchronous put(): setting batch linger time to zero ensures + // that each append is sent immediately without waiting for additional records to batch. 
+ flussConf.set(ConfigOptions.CLIENT_WRITER_BATCH_TIMEOUT, Duration.ZERO); + + flussConf.setString( + CLIENT_SECURITY_PROTOCOL, agentConfiguration.get(FLUSS_SECURITY_PROTOCOL)); + + flussConf.setString(CLIENT_SASL_MECHANISM, agentConfiguration.get(FLUSS_SASL_MECHANISM)); + + String jaasConfig = agentConfiguration.get(FLUSS_SASL_JAAS_CONFIG); + if (jaasConfig != null) { + flussConf.setString(CLIENT_SASL_JAAS_CONFIG, jaasConfig); + } + String username = agentConfiguration.get(FLUSS_SASL_USERNAME); + if (username != null) { + flussConf.setString(CLIENT_SASL_JAAS_USERNAME, username); + } + String password = agentConfiguration.get(FLUSS_SASL_PASSWORD); + if (password != null) { + flussConf.setString(CLIENT_SASL_JAAS_PASSWORD, password); + } + + this.connection = ConnectionFactory.createConnection(flussConf); + maybeCreateDatabaseAndTable(); + this.table = connection.getTable(tablePath); + this.writer = table.newAppend().createWriter(); + + LOG.info( + "Initialized FlussActionStateStore (log table) with table: {}.{}", + databaseName, + tableName); + } + + @Override + public void put(Object key, long seqNum, Action action, Event event, ActionState state) + throws Exception { + String stateKey = generateKey(key, seqNum, action, event); + byte[] payload = ActionStateSerde.serialize(state); + + GenericRow row = + GenericRow.of( + BinaryString.fromString(stateKey), + payload, + BinaryString.fromString(key.toString())); + + // Append to Fluss log for durability, then update in-memory cache. + // Synchronous write ensures the record is durable before returning. 
+ writer.append(row).get(); + actionStates.put(stateKey, state); + + LOG.debug("Stored action state: key={}, isCompleted={}", stateKey, state.isCompleted()); + } + + @Override + public ActionState get(Object key, long seqNum, Action action, Event event) throws Exception { + String stateKey = generateKey(key, seqNum, action, event); + + boolean hasDivergence = checkDivergence(key.toString(), seqNum); + + if (!actionStates.containsKey(stateKey) || hasDivergence) { + actionStates + .entrySet() + .removeIf( + entry -> { + try { + List<String> parts = ActionStateUtil.parseKey(entry.getKey()); + if (parts.size() >= 2) { + long stateSeqNum = Long.parseLong(parts.get(1)); + return stateSeqNum > seqNum; + } + } catch (NumberFormatException e) { + LOG.warn( + "Failed to parse sequence number from state key: {}", + stateKey); + } + return false; + }); + } + + ActionState state = actionStates.get(stateKey); + LOG.debug("Lookup action state: key={}, found={}", stateKey, state != null); + return state; + } + + private boolean checkDivergence(String key, long seqNum) { + return actionStates.keySet().stream() + .filter(k -> k.startsWith(key + "_" + seqNum + "_")) + .count() + > 1; + } + + /** + * Rebuilds in-memory state by scanning the Fluss log table. If recovery markers are provided, + * computes the minimum offset per bucket across all markers and subscribes from those offsets. + * Otherwise, skips rebuild since there is no checkpointed position to recover from. Reads from + * the start offset up to the latest offset captured at rebuild start. For the same state key + * appearing multiple times in the log, the latest record wins (last-write-wins). 
+ */ + @Override + public void rebuildState(List<Object> recoveryMarkers) { + LOG.info( + "Rebuilding action state from Fluss log table with {} recovery markers", + recoveryMarkers.size()); + + if (recoveryMarkers.isEmpty()) { + LOG.info("No recovery markers, skipping state rebuild"); + return; + } + + actionStates.clear(); + + // Compute per-bucket start offsets from recovery markers + Map<Integer, Long> bucketStartOffsets = new HashMap<>(); + for (Object marker : recoveryMarkers) { + if (marker instanceof Map) { + @SuppressWarnings("unchecked") + Map<Integer, Long> markerMap = (Map<Integer, Long>) marker; + for (Map.Entry<Integer, Long> entry : markerMap.entrySet()) { + bucketStartOffsets.merge(entry.getKey(), entry.getValue(), Math::min); + } + } else if (marker != null) { + LOG.warn( + "Ignoring unrecognized recovery marker type: {}", + marker.getClass().getName()); + } + } + + if (bucketStartOffsets.isEmpty()) { + LOG.info("No valid bucket offsets in recovery markers, skipping state rebuild"); + return; + } + + // Capture the latest offsets as the stopping point for each bucket + Map<Integer, Long> bucketEndOffsets = getBucketEndOffsets(); + // Capture the earliest available offsets so we can skip buckets whose data has + // been fully cleaned by log retention (earliestOffset >= endOffset). 
+ Map<Integer, Long> bucketEarliestOffsets = getBucketEarliestOffsets(); + LOG.debug( + "Rebuild window: startOffsets={}, earliestOffsets={}, endOffsets={}", + bucketStartOffsets, + bucketEarliestOffsets, + bucketEndOffsets); + + try (LogScanner scanner = table.newScan().createLogScanner()) { + // Track which buckets still need to be consumed + Map<Integer, Long> remainingBuckets = new HashMap<>(); + for (Map.Entry<Integer, Long> entry : bucketStartOffsets.entrySet()) { + int bucket = entry.getKey(); + long startOffset = entry.getValue(); + Long endOffset = bucketEndOffsets.get(bucket); + + // Bucket referenced in recovery marker does not exist in current table + if (endOffset == null) { Review Comment: Could we fail recovery instead of skipping buckets when the recovery window is invalid or incomplete? Cases like 1. `endOffset == null`, 2. `endOffset < startOffset`, 3. `earliestOffset > startOffset` all indicate that the marker no longer matches the available log range. Treating them as no-op or partially rebuilding from `earliestOffset` can silently drop action state from the checkpoint recovery window. Since this store is on the correctness path, I think these cases should fail fast rather than continue with an incomplete cache. ########## runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/TestAction.java: ########## @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.actionstate; + +import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.api.InputEvent; +import org.apache.flink.agents.api.context.RunnerContext; +import org.apache.flink.agents.plan.JavaFunction; +import org.apache.flink.agents.plan.actions.Action; + +import java.util.List; + +/** A no-op {@link Action} for use in action state store tests. */ +public class TestAction extends Action { Review Comment: ```suggestion public class NoOpAction extends Action { ``` ########## runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/FlussActionStateStore.java: ########## @@ -0,0 +1,489 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.agents.runtime.actionstate; + +import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.plan.AgentConfiguration; +import org.apache.flink.agents.plan.actions.Action; +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.fluss.client.Connection; +import org.apache.fluss.client.ConnectionFactory; +import org.apache.fluss.client.admin.Admin; +import org.apache.fluss.client.admin.OffsetSpec; +import org.apache.fluss.client.table.Table; +import org.apache.fluss.client.table.scanner.ScanRecord; +import org.apache.fluss.client.table.scanner.log.LogScanner; +import org.apache.fluss.client.table.scanner.log.ScanRecords; +import org.apache.fluss.client.table.writer.AppendWriter; +import org.apache.fluss.config.ConfigOptions; +import org.apache.fluss.config.Configuration; +import org.apache.fluss.metadata.DatabaseDescriptor; +import org.apache.fluss.metadata.Schema; +import org.apache.fluss.metadata.TableBucket; +import org.apache.fluss.metadata.TableDescriptor; +import org.apache.fluss.metadata.TablePath; +import org.apache.fluss.row.BinaryString; +import org.apache.fluss.row.GenericRow; +import org.apache.fluss.row.InternalRow; +import org.apache.fluss.types.DataTypes; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_ACTION_STATE_DATABASE; +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_ACTION_STATE_TABLE; +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_ACTION_STATE_TABLE_BUCKETS; +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_BOOTSTRAP_SERVERS; +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_SASL_JAAS_CONFIG; +import static 
org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_SASL_MECHANISM; +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_SASL_PASSWORD; +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_SASL_USERNAME; +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_SECURITY_PROTOCOL; +import static org.apache.flink.agents.runtime.actionstate.ActionStateUtil.generateKey; +import static org.apache.fluss.config.ConfigOptions.BOOTSTRAP_SERVERS; +import static org.apache.fluss.config.ConfigOptions.CLIENT_SASL_JAAS_CONFIG; +import static org.apache.fluss.config.ConfigOptions.CLIENT_SASL_JAAS_PASSWORD; +import static org.apache.fluss.config.ConfigOptions.CLIENT_SASL_JAAS_USERNAME; +import static org.apache.fluss.config.ConfigOptions.CLIENT_SASL_MECHANISM; +import static org.apache.fluss.config.ConfigOptions.CLIENT_SECURITY_PROTOCOL; + +/** + * An implementation of {@link ActionStateStore} that uses an Apache Fluss log table as the backend. + * All state is maintained in an in-memory map for fast lookups, with the Fluss log table providing + * durability and recovery support. 
+ */ +public class FlussActionStateStore implements ActionStateStore { + + private static final Logger LOG = LoggerFactory.getLogger(FlussActionStateStore.class); + + private static final Duration POLL_TIMEOUT = Duration.ofSeconds(1); + + // Column names in the Fluss table schema + private static final String COL_NAME_STATE_KEY = "state_key"; + private static final String COL_NAME_STATE_PAYLOAD = "state_payload"; + private static final String COL_NAME_AGENT_KEY = "agent_key"; + + // Column indices in the Fluss table schema + private static final int COL_STATE_KEY = 0; + private static final int COL_STATE_PAYLOAD = 1; + + private final AgentConfiguration agentConfiguration; + private final String databaseName; + private final String tableName; + private final TablePath tablePath; + + private final Connection connection; + private final Table table; + private final AppendWriter writer; + + /** In-memory cache for O(1) state lookups; rebuilt from Fluss log on recovery. */ + private final Map<String, ActionState> actionStates; + + @VisibleForTesting + FlussActionStateStore( + Map<String, ActionState> actionStates, + Connection connection, + Table table, + AppendWriter writer) { + this.agentConfiguration = null; + this.databaseName = null; + this.tableName = null; + this.tablePath = null; + this.actionStates = actionStates; + this.connection = connection; + this.table = table; + this.writer = writer; + } + + public FlussActionStateStore(AgentConfiguration agentConfiguration) { + this.agentConfiguration = agentConfiguration; + this.databaseName = agentConfiguration.get(FLUSS_ACTION_STATE_DATABASE); + this.tableName = agentConfiguration.get(FLUSS_ACTION_STATE_TABLE); + this.tablePath = TablePath.of(databaseName, tableName); + this.actionStates = new HashMap<>(); + + Configuration flussConf = new Configuration(); + flussConf.setString( + BOOTSTRAP_SERVERS.key(), agentConfiguration.get(FLUSS_BOOTSTRAP_SERVERS)); + // Minimize latency for synchronous put(): setting batch 
linger time to zero ensures + // that each append is sent immediately without waiting for additional records to batch. + flussConf.set(ConfigOptions.CLIENT_WRITER_BATCH_TIMEOUT, Duration.ZERO); + + flussConf.setString( + CLIENT_SECURITY_PROTOCOL, agentConfiguration.get(FLUSS_SECURITY_PROTOCOL)); + + flussConf.setString(CLIENT_SASL_MECHANISM, agentConfiguration.get(FLUSS_SASL_MECHANISM)); + + String jaasConfig = agentConfiguration.get(FLUSS_SASL_JAAS_CONFIG); + if (jaasConfig != null) { + flussConf.setString(CLIENT_SASL_JAAS_CONFIG, jaasConfig); + } + String username = agentConfiguration.get(FLUSS_SASL_USERNAME); + if (username != null) { + flussConf.setString(CLIENT_SASL_JAAS_USERNAME, username); + } + String password = agentConfiguration.get(FLUSS_SASL_PASSWORD); + if (password != null) { + flussConf.setString(CLIENT_SASL_JAAS_PASSWORD, password); + } + + this.connection = ConnectionFactory.createConnection(flussConf); + maybeCreateDatabaseAndTable(); + this.table = connection.getTable(tablePath); + this.writer = table.newAppend().createWriter(); + + LOG.info( + "Initialized FlussActionStateStore (log table) with table: {}.{}", + databaseName, + tableName); + } + + @Override + public void put(Object key, long seqNum, Action action, Event event, ActionState state) + throws Exception { + String stateKey = generateKey(key, seqNum, action, event); + byte[] payload = ActionStateSerde.serialize(state); + + GenericRow row = + GenericRow.of( + BinaryString.fromString(stateKey), + payload, + BinaryString.fromString(key.toString())); + + // Append to Fluss log for durability, then update in-memory cache. + // Synchronous write ensures the record is durable before returning. 
+ writer.append(row).get(); + actionStates.put(stateKey, state); + + LOG.debug("Stored action state: key={}, isCompleted={}", stateKey, state.isCompleted()); + } + + @Override + public ActionState get(Object key, long seqNum, Action action, Event event) throws Exception { + String stateKey = generateKey(key, seqNum, action, event); + + boolean hasDivergence = checkDivergence(key.toString(), seqNum); + + if (!actionStates.containsKey(stateKey) || hasDivergence) { + actionStates + .entrySet() + .removeIf( + entry -> { + try { + List<String> parts = ActionStateUtil.parseKey(entry.getKey()); + if (parts.size() >= 2) { + long stateSeqNum = Long.parseLong(parts.get(1)); + return stateSeqNum > seqNum; + } + } catch (NumberFormatException e) { + LOG.warn( + "Failed to parse sequence number from state key: {}", + stateKey); + } + return false; + }); + } + + ActionState state = actionStates.get(stateKey); + LOG.debug("Lookup action state: key={}, found={}", stateKey, state != null); + return state; + } + + private boolean checkDivergence(String key, long seqNum) { + return actionStates.keySet().stream() + .filter(k -> k.startsWith(key + "_" + seqNum + "_")) + .count() + > 1; + } + + /** + * Rebuilds in-memory state by scanning the Fluss log table. If recovery markers are provided, + * computes the minimum offset per bucket across all markers and subscribes from those offsets. + * Otherwise, skips rebuild since there is no checkpointed position to recover from. Reads from + * the start offset up to the latest offset captured at rebuild start. For the same state key + * appearing multiple times in the log, the latest record wins (last-write-wins). 
+ */ + @Override + public void rebuildState(List<Object> recoveryMarkers) { + LOG.info( + "Rebuilding action state from Fluss log table with {} recovery markers", + recoveryMarkers.size()); + + if (recoveryMarkers.isEmpty()) { + LOG.info("No recovery markers, skipping state rebuild"); + return; + } + + actionStates.clear(); + + // Compute per-bucket start offsets from recovery markers + Map<Integer, Long> bucketStartOffsets = new HashMap<>(); + for (Object marker : recoveryMarkers) { + if (marker instanceof Map) { + @SuppressWarnings("unchecked") + Map<Integer, Long> markerMap = (Map<Integer, Long>) marker; + for (Map.Entry<Integer, Long> entry : markerMap.entrySet()) { + bucketStartOffsets.merge(entry.getKey(), entry.getValue(), Math::min); + } + } else if (marker != null) { + LOG.warn( + "Ignoring unrecognized recovery marker type: {}", + marker.getClass().getName()); + } + } + + if (bucketStartOffsets.isEmpty()) { + LOG.info("No valid bucket offsets in recovery markers, skipping state rebuild"); + return; + } + + // Capture the latest offsets as the stopping point for each bucket + Map<Integer, Long> bucketEndOffsets = getBucketEndOffsets(); + // Capture the earliest available offsets so we can skip buckets whose data has + // been fully cleaned by log retention (earliestOffset >= endOffset). 
+ Map<Integer, Long> bucketEarliestOffsets = getBucketEarliestOffsets(); + LOG.debug( + "Rebuild window: startOffsets={}, earliestOffsets={}, endOffsets={}", + bucketStartOffsets, + bucketEarliestOffsets, + bucketEndOffsets); + + try (LogScanner scanner = table.newScan().createLogScanner()) { + // Track which buckets still need to be consumed + Map<Integer, Long> remainingBuckets = new HashMap<>(); + for (Map.Entry<Integer, Long> entry : bucketStartOffsets.entrySet()) { + int bucket = entry.getKey(); + long startOffset = entry.getValue(); + Long endOffset = bucketEndOffsets.get(bucket); + + // Bucket referenced in recovery marker does not exist in current table + if (endOffset == null) { + LOG.warn( + "Bucket {} referenced in recovery marker does not exist " + + "in current table, state recovery for this bucket is skipped", + bucket); + continue; + } + + // No new data since checkpoint (includes empty buckets that never had writes) + if (endOffset <= startOffset) { + LOG.info( + "Skipping bucket {} for rebuild: no new data " + + "(endOffset={} <= startOffset={})", + bucket, + endOffset, + startOffset); + continue; + } + + // Check if retention has cleaned data in the recovery window + Long earliestOffset = bucketEarliestOffsets.get(bucket); + long effectiveStart = startOffset; + if (earliestOffset != null && earliestOffset > startOffset) { + effectiveStart = earliestOffset; + if (effectiveStart >= endOffset) { + // All data in recovery window has been cleaned by retention + LOG.warn( + "Bucket {} state recovery failed: all data between offset {} " + + "and {} has been cleaned by log retention " + + "(earliest available: {})", + bucket, + startOffset, + endOffset, + earliestOffset); + continue; + } + // Partial data loss: some state between startOffset and earliestOffset is gone + LOG.warn( + "Bucket {} partial state loss: data between offset {} and {} " + + "has been cleaned by log retention, " + + "recovering from offset {} instead", + bucket, + startOffset, + 
earliestOffset, + effectiveStart); + } + + scanner.subscribe(bucket, effectiveStart); + remainingBuckets.put(bucket, endOffset); + } + LOG.debug("Subscribed buckets for rebuild: {}", remainingBuckets); + + while (!remainingBuckets.isEmpty()) { + + ScanRecords records = scanner.poll(POLL_TIMEOUT); + for (TableBucket bucket : records.buckets()) { + Long endOffset = remainingBuckets.get(bucket.getBucket()); + if (endOffset == null) { + continue; + } + // Track the highest offset seen in this batch (including skipped records + // beyond endOffset) so that we can detect when the bucket has been fully + // consumed even if the last records are past our target window. + long lastSeenOffset = -1; + for (ScanRecord record : records.records(bucket)) { + lastSeenOffset = record.logOffset(); + if (record.logOffset() >= endOffset) { + continue; + } + InternalRow row = record.getRow(); + String stateKey = row.getString(COL_STATE_KEY).toString(); + byte[] payload = row.getBytes(COL_STATE_PAYLOAD); + try { + ActionState state = ActionStateSerde.deserialize(payload); + actionStates.put(stateKey, state); + } catch (Exception e) { + LOG.warn( Review Comment: Should deserialization failures fail the rebuild instead of skipping the record? The Fluss log is the durable source for action state recovery. If a record in the recovery window cannot be restored, continuing with an incomplete cache may change replay behavior and cause actions to run again. I think rebuild should fail loudly on corrupted or incompatible state payloads. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
