This is an automated email from the ASF dual-hosted git repository.
fhueske pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new 59bb8b7a089 [FLINK-39379][table] Add state management to PTF test
harness (#28212)
59bb8b7a089 is described below
commit 59bb8b7a089b22665e06baee15c574541c402383
Author: Mika Naylor <[email protected]>
AuthorDate: Fri May 29 19:04:04 2026 +0200
[FLINK-39379][table] Add state management to PTF test harness (#28212)
Extends the PTF Test Harness to support the use of state (map, list, pojo,
and row).
The harness passes state objects to a PTF via eval() and, after eval()
returns, converts the state into its internal format (MapData/ArrayData/RowData)
so that state serialization is tested as well.
The harness can initialize, inspect (get), set, or clear individual states
of a PTF.
This change does not yet add support for state TTL.
---
docs/content.zh/docs/dev/table/functions/ptfs.md | 238 ++++++-
docs/content/docs/dev/table/functions/ptfs.md | 238 ++++++-
.../runtime/functions/ListViewStateConverter.java | 79 +++
.../runtime/functions/MapViewStateConverter.java | 92 +++
.../functions/ProcessTableFunctionTestHarness.java | 701 ++++++++++++++------
.../table/runtime/functions/RowStateConverter.java | 60 ++
.../table/runtime/functions/StateConverter.java | 38 ++
.../functions/StructuredTypeStateConverter.java | 67 ++
.../functions/TestHarnessDataTypeFactory.java | 4 +-
.../runtime/functions/TestHarnessStateManager.java | 233 +++++++
.../ProcessTableFunctionTestHarnessTest.java | 718 ++++++++++++++++++++-
11 files changed, 2230 insertions(+), 238 deletions(-)
diff --git a/docs/content.zh/docs/dev/table/functions/ptfs.md
b/docs/content.zh/docs/dev/table/functions/ptfs.md
index 4190e9e3ba5..bf24d7ced06 100644
--- a/docs/content.zh/docs/dev/table/functions/ptfs.md
+++ b/docs/content.zh/docs/dev/table/functions/ptfs.md
@@ -2275,6 +2275,240 @@ void testScalarOnly() throws Exception {
{{< /tab >}}
{{< /tabs >}}
+#### Testing with State
+
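+The harness supports all PTF state types: value state (POJO and `Row`), list state (`ListView`),
+and map state (`MapView`). After each `eval()` call, the harness converts the state into Flink's
+internal data format, so state serialization is covered by the test as well.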
+
+{{< tabs "state-testing" >}}
+{{< tab "Java" >}}
+```java
+// A PTF that uses all four state types: Pojo value state, Row value state,
ListView state, and MapView state.
+@DataTypeHint("ROW<count BIGINT>")
+public class StatefulPTF extends ProcessTableFunction<Row> {
+ public static class ValueState {
+ public long count = 0L;
+ }
+
+ public void eval(
+ @StateHint ValueState valueState,
+ @StateHint(type = @DataTypeHint("ROW<lastValue INT>")) Row rowState,
+ @StateHint(type = @DataTypeHint("ARRAY<INT>")) ListView<Integer> listState,
+ @StateHint MapView<String, Integer> mapState,
+ @ArgumentHint(ArgumentTrait.SET_SEMANTIC_TABLE) Row input) throws
Exception {
+ // Value state — increment counter
+ valueState.count++;
+
+ // Row state — track the last value seen
+ int value = input.getFieldAs("value");
+ rowState.setField("lastValue", value);
+
+ // ListView state — accumulate values
+ listState.add(value);
+
+ // MapView state — count occurrences by name
+ String name = input.getFieldAs("name");
+ Integer tagCount = mapState.get(name);
+ mapState.put(name, tagCount == null ? 1 : tagCount + 1);
+
+ collect(Row.of(valueState.count));
+ }
+}
+
+@Test
+void testWithState() throws Exception {
+ try (ProcessTableFunctionTestHarness<Row> harness =
+ ProcessTableFunctionTestHarness.ofClass(StatefulPTF.class)
+ .withTableArgument("input", DataTypes.of("ROW<name STRING, value INT>"))
+ .withPartitionBy("input", "name")
+ .build()) {
+
+ harness.processElement(Row.of("Alice", 10));
+ harness.processElement(Row.of("Alice", 20));
+
+ List<Row> output = harness.getOutput();
+ assertThat(output.get(0)).isEqualTo(Row.of("Alice", 1L));
+ assertThat(output.get(1)).isEqualTo(Row.of("Alice", 2L));
+ }
+}
+```
+{{< /tab >}}
+{{< /tabs >}}
+
+**Initial State Setup**: Use `.withInitialStateForKey()` to pre-populate state
before processing.
+State initialization is scoped per partition key:
+
+{{< tabs "initial-state" >}}
+{{< tab "Java" >}}
+```java
+@Test
+void testWithInitialState() throws Exception {
+ // Value state
+ StatefulPTF.ValueState initialValue = new StatefulPTF.ValueState();
+ initialValue.count = 100L;
+
+ // Row state
+ Row initialRow = Row.withNames();
+ initialRow.setField("lastValue", 42);
+
+ // ListView state
+ ListView<Integer> initialList = new ListView<>();
+ initialList.add(10);
+ initialList.add(20);
+
+ // MapView state
+ MapView<String, Integer> initialMap = new MapView<>();
+ initialMap.put("Alice", 5);
+
+ try (ProcessTableFunctionTestHarness<Row> harness =
+ ProcessTableFunctionTestHarness.ofClass(StatefulPTF.class)
+ .withTableArgument("input", DataTypes.of("ROW<name STRING, value INT>"))
+ .withPartitionBy("input", "name")
+ // Initial state is set per partition key
+ .withInitialStateForKey("valueState", Row.of("Alice"), initialValue)
+ .withInitialStateForKey("rowState", Row.of("Alice"), initialRow)
+ .withInitialStateForKey("listState", Row.of("Alice"), initialList)
+ .withInitialStateForKey("mapState", Row.of("Alice"), initialMap)
+ .build()) {
+
+ harness.processElement(Row.of("Alice", 10));
+
+ List<Row> output = harness.getOutput();
+ assertThat(output).containsExactly(Row.of("Alice", 101L));
+ }
+}
+```
+{{< /tab >}}
+{{< /tabs >}}
+
+**State Introspection**: Use `getStateForKey()`, `getKeysForState()`, and
`getStateForAllKeys()` to
+inspect state during tests:
+
+{{< tabs "state-introspection" >}}
+{{< tab "Java" >}}
+```java
+@Test
+void testStateIntrospection() throws Exception {
+ try (ProcessTableFunctionTestHarness<Row> harness =
+ ProcessTableFunctionTestHarness.ofClass(StatefulPTF.class)
+ .withTableArgument("input", DataTypes.of("ROW<name STRING, value INT>"))
+ .withPartitionBy("input", "name")
+ .build()) {
+
+ harness.processElement(Row.of("Alice", 10));
+ harness.processElement(Row.of("Bob", 20));
+
+ // Check value state
+ StatefulPTF.ValueState aliceState =
+ harness.getStateForKey("valueState", Row.of("Alice"));
+ assertThat(aliceState.count).isEqualTo(1L);
+
+ // Check Row state
+ Row aliceRowState = harness.getStateForKey("rowState", Row.of("Alice"));
+ assertThat(aliceRowState.getField("lastValue")).isEqualTo(10);
+
+ // Check ListView state
+ ListView<Integer> aliceList = harness.getStateForKey("listState",
Row.of("Alice"));
+ assertThat(aliceList.getList()).containsExactly(10);
+
+ // Check MapView state
+ MapView<String, Integer> aliceMap = harness.getStateForKey("mapState",
Row.of("Alice"));
+ assertThat(aliceMap.get("Alice")).isEqualTo(1);
+
+ // Get all partition keys with state
+ Set<Row> keys = harness.getKeysForState("valueState");
+ assertThat(keys).containsExactlyInAnyOrder(Row.of("Alice"), Row.of("Bob"));
+
+ // Get all state across partition keys
+ Map<Row, StatefulPTF.ValueState> allState =
+ harness.getStateForAllKeys("valueState");
+ assertThat(allState.get(Row.of("Bob")).count).isEqualTo(1L);
+ }
+}
+```
+{{< /tab >}}
+{{< /tabs >}}
+
+**State Mutation**: Use `setStateForKey()`, `clearAllStatesForKey()`, and
`clearStateForKey()` to
+modify state during tests:
+
+{{< tabs "state-mutation" >}}
+{{< tab "Java" >}}
+```java
+@Test
+void testStateMutation() throws Exception {
+ try (ProcessTableFunctionTestHarness<Row> harness =
+ ProcessTableFunctionTestHarness.ofClass(StatefulPTF.class)
+ .withTableArgument("input", DataTypes.of("ROW<name STRING, value INT>"))
+ .withPartitionBy("input", "name")
+ .build()) {
+
+ harness.processElement(Row.of("Alice", 10));
+
+ // Overwrite a specific state entry for a partition key
+ StatefulPTF.ValueState newState = new StatefulPTF.ValueState();
+ newState.count = 100L;
+ harness.setStateForKey("valueState", Row.of("Alice"), newState);
+
+ // Verify the state was updated
+ StatefulPTF.ValueState state = harness.getStateForKey("valueState",
Row.of("Alice"));
+ assertThat(state.count).isEqualTo(100L);
+
+ // Clear a specific state entry (resets to default)
+ harness.clearStateForKey("listState", Row.of("Alice"));
+
+ // Clear all state for a partition key
+ harness.clearAllStatesForKey(Row.of("Alice"));
+ }
+}
+```
+{{< /tab >}}
+{{< /tabs >}}
+
+#### Optional Partitioning
+
+For PTFs with `OPTIONAL_PARTITION_BY`, you can omit `withPartitionBy()` during harness setup. The
+harness then executes the function as if it had a parallelism of 1: all rows share the empty
+partition key `Row.of()` and are routed through the same function instance. Pass `Row.of()` as the
+partition key to access state:
+
+{{< tabs "optional-partition" >}}
+{{< tab "Java" >}}
+```java
+// A PTF with optional partitioning that counts all rows.
+@DataTypeHint("ROW<count BIGINT>")
+public class GlobalCountPTF extends ProcessTableFunction<Row> {
+ public static class CountState {
+ public long count = 0L;
+ }
+
+ public void eval(
+ @StateHint CountState state,
+ @ArgumentHint({ArgumentTrait.SET_SEMANTIC_TABLE,
ArgumentTrait.OPTIONAL_PARTITION_BY})
+ Row input) {
+ state.count++;
+ collect(Row.of(state.count));
+ }
+}
+
+@Test
+void testOptionalPartitionWithoutPartitionBy() throws Exception {
+ try (ProcessTableFunctionTestHarness<Row> harness =
+ ProcessTableFunctionTestHarness.ofClass(GlobalCountPTF.class)
+ .withTableArgument("input", DataTypes.of("ROW<key STRING, value INT>"))
+ .build()) {
+
+ harness.processElement(Row.of("A", 10));
+ harness.processElement(Row.of("B", 20));
+
+ // All data shares a single Row.of() partition key
+ GlobalCountPTF.CountState state = harness.getStateForKey("state",
Row.of());
+ assertThat(state.count).isEqualTo(2L);
+ }
+}
+```
+{{< /tab >}}
+{{< /tabs >}}
+
#### Configuring Table Argument Types
In contexts where the harness can't infer the table argument types for table
arguments (when using unannotated `Row` inputs,
@@ -2348,8 +2582,8 @@ void testPOJO() throws Exception {
### PTF Features Unsupported by the TestHarness
-- `Context` paramter
-- State (`@StateHint`)
+- `Context` parameter
- Timers (`onTimer`)
- `on_time` / `rowtime`
- Update traits (`SUPPORTS_UPDATES`, `REQUIRE_UPDATE_BEFORE`)
+- State TTL (state is supported but TTL expiration is not yet implemented)
diff --git a/docs/content/docs/dev/table/functions/ptfs.md
b/docs/content/docs/dev/table/functions/ptfs.md
index 7181fc1a3d1..2abf34fd066 100644
--- a/docs/content/docs/dev/table/functions/ptfs.md
+++ b/docs/content/docs/dev/table/functions/ptfs.md
@@ -2278,6 +2278,240 @@ void testScalarOnly() throws Exception {
{{< /tab >}}
{{< /tabs >}}
+#### Testing with State
+
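+The harness supports all PTF state types: value state (POJO and `Row`), list state (`ListView`),
+and map state (`MapView`). After each `eval()` call, the harness converts the state into Flink's
+internal data format, so state serialization is covered by the test as well.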
+
+{{< tabs "state-testing" >}}
+{{< tab "Java" >}}
+```java
+// A PTF that uses all four state types: Pojo value state, Row value state,
ListView state, and MapView state.
+@DataTypeHint("ROW<count BIGINT>")
+public class StatefulPTF extends ProcessTableFunction<Row> {
+ public static class ValueState {
+ public long count = 0L;
+ }
+
+ public void eval(
+ @StateHint ValueState valueState,
+ @StateHint(type = @DataTypeHint("ROW<lastValue INT>")) Row rowState,
+ @StateHint(type = @DataTypeHint("ARRAY<INT>")) ListView<Integer> listState,
+ @StateHint MapView<String, Integer> mapState,
+ @ArgumentHint(ArgumentTrait.SET_SEMANTIC_TABLE) Row input) throws
Exception {
+ // Value state — increment counter
+ valueState.count++;
+
+ // Row state — track the last value seen
+ int value = input.getFieldAs("value");
+ rowState.setField("lastValue", value);
+
+ // ListView state — accumulate values
+ listState.add(value);
+
+ // MapView state — count occurrences by name
+ String name = input.getFieldAs("name");
+ Integer tagCount = mapState.get(name);
+ mapState.put(name, tagCount == null ? 1 : tagCount + 1);
+
+ collect(Row.of(valueState.count));
+ }
+}
+
+@Test
+void testWithState() throws Exception {
+ try (ProcessTableFunctionTestHarness<Row> harness =
+ ProcessTableFunctionTestHarness.ofClass(StatefulPTF.class)
+ .withTableArgument("input", DataTypes.of("ROW<name STRING, value INT>"))
+ .withPartitionBy("input", "name")
+ .build()) {
+
+ harness.processElement(Row.of("Alice", 10));
+ harness.processElement(Row.of("Alice", 20));
+
+ List<Row> output = harness.getOutput();
+ assertThat(output.get(0)).isEqualTo(Row.of("Alice", 1L));
+ assertThat(output.get(1)).isEqualTo(Row.of("Alice", 2L));
+ }
+}
+```
+{{< /tab >}}
+{{< /tabs >}}
+
+**Initial State Setup**: Use `.withInitialStateForKey()` to pre-populate state
before processing.
+State initialization is scoped per partition key:
+
+{{< tabs "initial-state" >}}
+{{< tab "Java" >}}
+```java
+@Test
+void testWithInitialState() throws Exception {
+ // Value state
+ StatefulPTF.ValueState initialValue = new StatefulPTF.ValueState();
+ initialValue.count = 100L;
+
+ // Row state
+ Row initialRow = Row.withNames();
+ initialRow.setField("lastValue", 42);
+
+ // ListView state
+ ListView<Integer> initialList = new ListView<>();
+ initialList.add(10);
+ initialList.add(20);
+
+ // MapView state
+ MapView<String, Integer> initialMap = new MapView<>();
+ initialMap.put("Alice", 5);
+
+ try (ProcessTableFunctionTestHarness<Row> harness =
+ ProcessTableFunctionTestHarness.ofClass(StatefulPTF.class)
+ .withTableArgument("input", DataTypes.of("ROW<name STRING, value INT>"))
+ .withPartitionBy("input", "name")
+ // Initial state is set per partition key
+ .withInitialStateForKey("valueState", Row.of("Alice"), initialValue)
+ .withInitialStateForKey("rowState", Row.of("Alice"), initialRow)
+ .withInitialStateForKey("listState", Row.of("Alice"), initialList)
+ .withInitialStateForKey("mapState", Row.of("Alice"), initialMap)
+ .build()) {
+
+ harness.processElement(Row.of("Alice", 10));
+
+ List<Row> output = harness.getOutput();
+ assertThat(output).containsExactly(Row.of("Alice", 101L));
+ }
+}
+```
+{{< /tab >}}
+{{< /tabs >}}
+
+**State Introspection**: Use `getStateForKey()`, `getKeysForState()`, and
`getStateForAllKeys()` to
+inspect state during tests:
+
+{{< tabs "state-introspection" >}}
+{{< tab "Java" >}}
+```java
+@Test
+void testStateIntrospection() throws Exception {
+ try (ProcessTableFunctionTestHarness<Row> harness =
+ ProcessTableFunctionTestHarness.ofClass(StatefulPTF.class)
+ .withTableArgument("input", DataTypes.of("ROW<name STRING, value INT>"))
+ .withPartitionBy("input", "name")
+ .build()) {
+
+ harness.processElement(Row.of("Alice", 10));
+ harness.processElement(Row.of("Bob", 20));
+
+ // Check value state
+ StatefulPTF.ValueState aliceState =
+ harness.getStateForKey("valueState", Row.of("Alice"));
+ assertThat(aliceState.count).isEqualTo(1L);
+
+ // Check Row state
+ Row aliceRowState = harness.getStateForKey("rowState", Row.of("Alice"));
+ assertThat(aliceRowState.getField("lastValue")).isEqualTo(10);
+
+ // Check ListView state
+ ListView<Integer> aliceList = harness.getStateForKey("listState",
Row.of("Alice"));
+ assertThat(aliceList.getList()).containsExactly(10);
+
+ // Check MapView state
+ MapView<String, Integer> aliceMap = harness.getStateForKey("mapState",
Row.of("Alice"));
+ assertThat(aliceMap.get("Alice")).isEqualTo(1);
+
+ // Get all partition keys with state
+ Set<Row> keys = harness.getKeysForState("valueState");
+ assertThat(keys).containsExactlyInAnyOrder(Row.of("Alice"), Row.of("Bob"));
+
+ // Get all state across partition keys
+ Map<Row, StatefulPTF.ValueState> allState =
+ harness.getStateForAllKeys("valueState");
+ assertThat(allState.get(Row.of("Bob")).count).isEqualTo(1L);
+ }
+}
+```
+{{< /tab >}}
+{{< /tabs >}}
+
+**State Mutation**: Use `setStateForKey()`, `clearAllStatesForKey()`, and
`clearStateForKey()` to
+modify state during tests:
+
+{{< tabs "state-mutation" >}}
+{{< tab "Java" >}}
+```java
+@Test
+void testStateMutation() throws Exception {
+ try (ProcessTableFunctionTestHarness<Row> harness =
+ ProcessTableFunctionTestHarness.ofClass(StatefulPTF.class)
+ .withTableArgument("input", DataTypes.of("ROW<name STRING, value INT>"))
+ .withPartitionBy("input", "name")
+ .build()) {
+
+ harness.processElement(Row.of("Alice", 10));
+
+ // Overwrite a specific state entry for a partition key
+ StatefulPTF.ValueState newState = new StatefulPTF.ValueState();
+ newState.count = 100L;
+ harness.setStateForKey("valueState", Row.of("Alice"), newState);
+
+ // Verify the state was updated
+ StatefulPTF.ValueState state = harness.getStateForKey("valueState",
Row.of("Alice"));
+ assertThat(state.count).isEqualTo(100L);
+
+ // Clear a specific state entry (resets to default)
+ harness.clearStateForKey("listState", Row.of("Alice"));
+
+ // Clear all state for a partition key
+ harness.clearAllStatesForKey(Row.of("Alice"));
+ }
+}
+```
+{{< /tab >}}
+{{< /tabs >}}
+
+#### Optional Partitioning
+
+For PTFs with `OPTIONAL_PARTITION_BY`, you can omit `withPartitionBy()` during harness setup. The
+harness then executes the function as if it had a parallelism of 1: all rows share the empty
+partition key `Row.of()` and are routed through the same function instance. Pass `Row.of()` as the
+partition key to access state:
+
+{{< tabs "optional-partition" >}}
+{{< tab "Java" >}}
+```java
+// A PTF with optional partitioning that counts all rows.
+@DataTypeHint("ROW<count BIGINT>")
+public class GlobalCountPTF extends ProcessTableFunction<Row> {
+ public static class CountState {
+ public long count = 0L;
+ }
+
+ public void eval(
+ @StateHint CountState state,
+ @ArgumentHint({ArgumentTrait.SET_SEMANTIC_TABLE,
ArgumentTrait.OPTIONAL_PARTITION_BY})
+ Row input) {
+ state.count++;
+ collect(Row.of(state.count));
+ }
+}
+
+@Test
+void testOptionalPartitionWithoutPartitionBy() throws Exception {
+ try (ProcessTableFunctionTestHarness<Row> harness =
+ ProcessTableFunctionTestHarness.ofClass(GlobalCountPTF.class)
+ .withTableArgument("input", DataTypes.of("ROW<key STRING, value INT>"))
+ .build()) {
+
+ harness.processElement(Row.of("A", 10));
+ harness.processElement(Row.of("B", 20));
+
+ // All data shares a single Row.of() partition key
+ GlobalCountPTF.CountState state = harness.getStateForKey("state",
Row.of());
+ assertThat(state.count).isEqualTo(2L);
+ }
+}
+```
+{{< /tab >}}
+{{< /tabs >}}
+
#### Configuring Table Argument Types
In contexts where the harness can't infer the table argument types for table
arguments (when using unannotated `Row` inputs,
@@ -2351,8 +2585,8 @@ void testPOJO() throws Exception {
### PTF Features Unsupported by the TestHarness
-- `Context` paramter
-- State (`@StateHint`)
+- `Context` parameter
- Timers (`onTimer`)
- `on_time` / `rowtime`
- Update traits (`SUPPORTS_UPDATES`, `REQUIRE_UPDATE_BEFORE`)
+- State TTL (state is supported but TTL expiration is not yet implemented)
diff --git
a/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/ListViewStateConverter.java
b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/ListViewStateConverter.java
new file mode 100644
index 00000000000..07ed63309a8
--- /dev/null
+++
b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/ListViewStateConverter.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.functions;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.api.dataview.ListView;
+import org.apache.flink.table.data.ArrayData;
+import org.apache.flink.table.data.GenericArrayData;
+import org.apache.flink.table.data.conversion.DataStructureConverter;
+import org.apache.flink.table.types.logical.ArrayType;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Converter for {@link ListView} state.
+ *
+ * <p>Converts between the external {@link ListView} objects handed to a PTF and the internal
+ * {@link ArrayData} representation. Each element is converted individually with the element
+ * converter; {@code null} elements are kept as {@code null} in both directions.
+ */
+@Internal
+class ListViewStateConverter implements StateConverter {
+
+    /** Converts single list elements between their external and internal representation. */
+    private final DataStructureConverter<Object, Object> elementConverter;
+
+    /** Reads (possibly null) elements from internal {@link ArrayData} by position. */
+    private final ArrayData.ElementGetter elementGetter;
+
+    ListViewStateConverter(
+            ArrayType arrayType, DataStructureConverter<Object, Object> elementConverter) {
+        this.elementConverter = elementConverter;
+        this.elementGetter = ArrayData.createElementGetter(arrayType.getElementType());
+    }
+
+    /**
+     * Converts a {@link ListView} into a {@link GenericArrayData} holding internal elements.
+     *
+     * @param external a {@link ListView}
+     * @return the internal array representation
+     */
+    @Override
+    public Object toInternal(Object external) {
+        final List<?> elements = ((ListView<?>) external).getList();
+
+        final Object[] internalArray = new Object[elements.size()];
+        int i = 0;
+        // Iterate rather than index: the backing list may have been replaced via
+        // ListView#setList with a non-RandomAccess list, where get(i) is O(n).
+        for (Object element : elements) {
+            // Element converters are not required to accept null, so use the null-safe variant.
+            internalArray[i++] = elementConverter.toInternalOrNull(element);
+        }
+        return new GenericArrayData(internalArray);
+    }
+
+    /**
+     * Converts internal {@link ArrayData} back into a freshly created {@link ListView}.
+     *
+     * @param internal an {@link ArrayData}
+     * @return a new {@link ListView} containing the external elements
+     */
+    @Override
+    public Object toExternal(Object internal) {
+        final ArrayData arrayData = (ArrayData) internal;
+        final int size = arrayData.size();
+
+        final List<Object> elements = new ArrayList<>(size);
+        for (int i = 0; i < size; i++) {
+            final Object internalElement = elementGetter.getElementOrNull(arrayData, i);
+            // getElementOrNull may yield null; don't hand it to a converter that can't take it.
+            elements.add(elementConverter.toExternalOrNull(internalElement));
+        }
+
+        final ListView<Object> listView = new ListView<>();
+        listView.setList(elements);
+        return listView;
+    }
+
+    /** Returns an empty internal array used as the initial state. */
+    @Override
+    public Object createNewInternalState() {
+        return new GenericArrayData(new Object[0]);
+    }
+}
diff --git
a/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/MapViewStateConverter.java
b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/MapViewStateConverter.java
new file mode 100644
index 00000000000..b173f8d090c
--- /dev/null
+++
b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/MapViewStateConverter.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.functions;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.api.dataview.MapView;
+import org.apache.flink.table.data.ArrayData;
+import org.apache.flink.table.data.GenericMapData;
+import org.apache.flink.table.data.MapData;
+import org.apache.flink.table.data.conversion.DataStructureConverter;
+import org.apache.flink.table.types.logical.MapType;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Converter for {@link MapView} state.
+ *
+ * <p>Converts between the external {@link MapView} objects handed to a PTF and the internal
+ * {@link MapData} representation. Keys and values are converted individually; {@code null} keys
+ * and values are kept as {@code null} in both directions.
+ */
+@Internal
+class MapViewStateConverter implements StateConverter {
+
+    /** Converts map keys between their external and internal representation. */
+    private final DataStructureConverter<Object, Object> keyConverter;
+
+    /** Converts map values between their external and internal representation. */
+    private final DataStructureConverter<Object, Object> valueConverter;
+
+    /** Reads (possibly null) keys from the internal key array by position. */
+    private final ArrayData.ElementGetter keyGetter;
+
+    /** Reads (possibly null) values from the internal value array by position. */
+    private final ArrayData.ElementGetter valueGetter;
+
+    MapViewStateConverter(
+            MapType mapType,
+            DataStructureConverter<Object, Object> keyConverter,
+            DataStructureConverter<Object, Object> valueConverter) {
+        this.keyConverter = keyConverter;
+        this.valueConverter = valueConverter;
+        this.keyGetter = ArrayData.createElementGetter(mapType.getKeyType());
+        this.valueGetter = ArrayData.createElementGetter(mapType.getValueType());
+    }
+
+    /**
+     * Converts a {@link MapView} into a {@link GenericMapData} holding internal keys and values.
+     *
+     * @param external a {@link MapView}
+     * @return the internal map representation
+     */
+    @Override
+    public Object toInternal(Object external) {
+        final Map<?, ?> entries = ((MapView<?, ?>) external).getMap();
+
+        final Map<Object, Object> internalMap = new HashMap<>();
+        for (Map.Entry<?, ?> entry : entries.entrySet()) {
+            // Converters are not required to accept null, so use the null-safe variants.
+            internalMap.put(
+                    keyConverter.toInternalOrNull(entry.getKey()),
+                    valueConverter.toInternalOrNull(entry.getValue()));
+        }
+        return new GenericMapData(internalMap);
+    }
+
+    /**
+     * Converts internal {@link MapData} back into a freshly created {@link MapView}.
+     *
+     * @param internal a {@link MapData}
+     * @return a new {@link MapView} containing the external entries
+     */
+    @Override
+    public Object toExternal(Object internal) {
+        final MapData mapData = (MapData) internal;
+        final ArrayData keyArray = mapData.keyArray();
+        final ArrayData valueArray = mapData.valueArray();
+
+        final Map<Object, Object> entries = new HashMap<>();
+        for (int i = 0; i < keyArray.size(); i++) {
+            // getElementOrNull may yield null; don't hand it to a converter that can't take it.
+            final Object externalKey =
+                    keyConverter.toExternalOrNull(keyGetter.getElementOrNull(keyArray, i));
+            final Object externalValue =
+                    valueConverter.toExternalOrNull(valueGetter.getElementOrNull(valueArray, i));
+            entries.put(externalKey, externalValue);
+        }
+
+        final MapView<Object, Object> mapView = new MapView<>();
+        mapView.setMap(entries);
+        return mapView;
+    }
+
+    /** Returns an empty internal map used as the initial state. */
+    @Override
+    public Object createNewInternalState() {
+        return new GenericMapData(new HashMap<>());
+    }
+}
diff --git
a/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/ProcessTableFunctionTestHarness.java
b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/ProcessTableFunctionTestHarness.java
index 425c88c3753..5a6eecb50f8 100644
---
a/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/ProcessTableFunctionTestHarness.java
+++
b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/ProcessTableFunctionTestHarness.java
@@ -20,8 +20,8 @@ package org.apache.flink.table.runtime.functions;
import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.table.annotation.ArgumentTrait;
-import org.apache.flink.table.annotation.StateHint;
import org.apache.flink.table.catalog.DataTypeFactory;
+import org.apache.flink.table.data.RowData;
import org.apache.flink.table.data.conversion.DataStructureConverter;
import org.apache.flink.table.data.conversion.DataStructureConverters;
import org.apache.flink.table.functions.FunctionContext;
@@ -31,12 +31,15 @@ import org.apache.flink.table.functions.TableSemantics;
import org.apache.flink.table.types.AbstractDataType;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.FieldsDataType;
+import org.apache.flink.table.types.inference.StateTypeStrategy;
import org.apache.flink.table.types.inference.StaticArgument;
import org.apache.flink.table.types.inference.StaticArgumentTrait;
import org.apache.flink.table.types.inference.SystemTypeInference;
import org.apache.flink.table.types.inference.TypeInference;
import org.apache.flink.table.types.inference.TypeStrategy;
+import org.apache.flink.table.types.logical.ArrayType;
import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.MapType;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.table.types.logical.StructuredType;
import org.apache.flink.table.types.utils.TypeConversions;
@@ -47,6 +50,7 @@ import org.apache.flink.util.Collector;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
+import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.EnumSet;
@@ -56,6 +60,7 @@ import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
+import java.util.stream.Collectors;
import static org.apache.flink.util.Preconditions.checkArgument;
import static org.apache.flink.util.Preconditions.checkNotNull;
@@ -85,16 +90,16 @@ import static
org.apache.flink.util.Preconditions.checkState;
@PublicEvolving
public class ProcessTableFunctionTestHarness<OUT> implements AutoCloseable {
- /** Holds input and output converters for a table argument. */
- private static class ConverterPair {
- final DataStructureConverter<Object, Object> input;
- final DataStructureConverter<Object, Object> output;
+ /** Holds converters for transforming table argument input rows. */
+ private static class TableArgumentConverters {
+ final DataStructureConverter<Object, Object> toNamedRow;
+ final DataStructureConverter<Object, Object> toEvalArgument;
- ConverterPair(
- DataStructureConverter<Object, Object> input,
- DataStructureConverter<Object, Object> output) {
- this.input = input;
- this.output = output;
+ TableArgumentConverters(
+ DataStructureConverter<Object, Object> toNamedRow,
+ DataStructureConverter<Object, Object> toEvalArgument) {
+ this.toNamedRow = toNamedRow;
+ this.toEvalArgument = toEvalArgument;
}
}
@@ -103,6 +108,7 @@ public class ProcessTableFunctionTestHarness<OUT>
implements AutoCloseable {
private final List<OUT> output;
private boolean isOpen;
private final HarnessCollector collector;
+ private final TestHarnessStateManager stateManager;
private final String defaultTableArgument;
private final Method evalMethod;
@@ -110,10 +116,9 @@ public class ProcessTableFunctionTestHarness<OUT>
implements AutoCloseable {
private final Map<String, ArgumentInfo> argumentsByName;
private final boolean isSingleTableFunction;
- private final Map<String, Object> scalarArgumentValues;
private boolean hasTableArguments = false;
- private final Map<String, ConverterPair> argumentConverters;
+ private final Map<String, TableArgumentConverters> argumentConverters;
private final DataStructureConverter<Object, Object>
harnessOutputConverter;
private ProcessTableFunctionTestHarness(
@@ -121,17 +126,17 @@ public class ProcessTableFunctionTestHarness<OUT>
implements AutoCloseable {
FunctionContext functionContext,
Method evalMethod,
List<ArgumentInfo> arguments,
- Map<String, Object> scalarArgumentValues,
- Map<String, ConverterPair> argumentConverters,
- DataStructureConverter<Object, Object> harnessOutputConverter)
+ Map<String, TableArgumentConverters> argumentConverters,
+ DataStructureConverter<Object, Object> harnessOutputConverter,
+ TestHarnessStateManager stateManager)
throws Exception {
this.function = function;
this.functionContext = functionContext;
this.evalMethod = evalMethod;
this.arguments = arguments;
- this.scalarArgumentValues = scalarArgumentValues;
this.argumentConverters = argumentConverters;
this.harnessOutputConverter = harnessOutputConverter;
+ this.stateManager = stateManager;
this.output = new ArrayList<>();
this.collector = new HarnessCollector();
this.isOpen = false;
@@ -143,13 +148,8 @@ public class ProcessTableFunctionTestHarness<OUT>
implements AutoCloseable {
}
}
- final List<ArgumentInfo> tableArguments = new ArrayList<>();
- for (ArgumentInfo arg : arguments) {
- if (arg.isTableArgument) {
- tableArguments.add(arg);
- this.hasTableArguments = true;
- }
- }
+ final List<TableArgumentInfo> tableArguments =
ArgumentInfo.filterTableArguments(arguments);
+ this.hasTableArguments = !tableArguments.isEmpty();
if (tableArguments.size() == 1) {
this.defaultTableArgument = tableArguments.get(0).name;
@@ -213,11 +213,14 @@ public class ProcessTableFunctionTestHarness<OUT>
implements AutoCloseable {
checkState(isOpen, "Harness not open");
checkNotNull(tableArgument, "tableArgument must not be null");
- ArgumentInfo tableArg = argumentsByName.get(tableArgument);
- if (tableArg == null) {
+ ArgumentInfo arg = argumentsByName.get(tableArgument);
+ if (arg == null) {
throw new IllegalArgumentException("Unknown table argument: " +
tableArgument);
}
- invokeEval(tableArg, row);
+ if (!(arg instanceof TableArgumentInfo)) {
+ throw new IllegalArgumentException("'" + tableArgument + "' is not
a table argument");
+ }
+ invokeEval((TableArgumentInfo) arg, row);
}
/** Process a single element for a specific table argument. */
@@ -250,7 +253,7 @@ public class ProcessTableFunctionTestHarness<OUT>
implements AutoCloseable {
+ "Use processElement() or
processElementForTable() instead.");
}
- Object[] args = arguments.stream().map(arg ->
scalarArgumentValues.get(arg.name)).toArray();
+ Object[] args = arguments.stream().map(arg -> ((ScalarArgumentInfo)
arg).value).toArray();
try {
evalMethod.invoke(function, args);
@@ -270,43 +273,69 @@ public class ProcessTableFunctionTestHarness<OUT>
implements AutoCloseable {
output.clear();
}
- /**
- * Given a target table argument and a row to process, construct the right
set of arguments for
- * the PTF's eval function and attempt to invoke it.
- */
- private void invokeEval(ArgumentInfo activeTableArg, Row activeRow) throws
Exception {
- // Set collector context so it can prepend columns if needed
- collector.setContext(activeTableArg, activeRow);
+ /** Get state for a specific partition key. */
+ public <T> T getStateForKey(String stateName, Row partitionKey) {
+ return stateManager.getStateForKey(stateName, partitionKey);
+ }
- Object[] args = new Object[arguments.size()];
+ /** Set state for a specific partition key. */
+ public void setStateForKey(String stateName, Row partitionKey, Object
state) throws Exception {
+ stateManager.setStateForKey(stateName, partitionKey, state);
+ }
- for (int i = 0; i < arguments.size(); i++) {
- ArgumentInfo arg = arguments.get(i);
+ /** Get all partition keys that have a specific state entry. */
+ public Set<Row> getKeysForState(String stateName) {
+ return stateManager.getKeysForState(stateName);
+ }
+
+ /** Get all state values for a state name across all partition keys. */
+ public <T> Map<Row, T> getStateForAllKeys(String stateName) {
+ return stateManager.getStateForAllKeys(stateName);
+ }
+
+ /** Clear all state for a given partition key. */
+ public void clearAllStatesForKey(Row partitionKey) {
+ stateManager.clearAllStatesForKey(partitionKey);
+ }
- if (arg.isTableArgument && arg.name.equals(activeTableArg.name)) {
- // If the argument is the active table argument, first convert
the input row
- // to an internal RowData type, and then convert the RowData
to type that the
- // argument expects. For Rows, this will structure the Row
based on the table
- // argument structure. Otherwise, for POJOs, it will pass the
expected POJO to eval.
+ /** Clear specific state entry for a given partition key. */
+ public void clearStateForKey(String stateName, Row partitionKey) {
+ stateManager.clearStateForKey(stateName, partitionKey);
+ }
- ConverterPair pair = argumentConverters.get(arg.name);
+ private void invokeEval(TableArgumentInfo activeTableArg, Row activeRow)
throws Exception {
+ TableArgumentConverters converters =
argumentConverters.get(activeTableArg.name);
- args[i] =
pair.output.toExternalOrNull(pair.input.toInternalOrNull(activeRow));
+ RowData rowData = (RowData)
converters.toNamedRow.toInternal(activeRow);
+ Row namedRow = (Row) converters.toNamedRow.toExternal(rowData);
+ Object evalArgument = converters.toEvalArgument.toExternal(rowData);
- } else if (arg.isScalar) {
- args[i] = scalarArgumentValues.get(arg.name);
+ collector.setContext(activeTableArg, namedRow);
- } else if (arg.isTableArgument) {
- // Inactive table arguments receive null
- args[i] = null;
- } else {
- throw new IllegalStateException(
- "Unexpected argument type at position " + i + ": " +
arg.name);
+ Row partitionKey = extractPartitionKey(activeTableArg, namedRow);
+ Map<String, Object> stateMap =
stateManager.loadStateForKey(partitionKey);
+
+ Object[] args = new Object[arguments.size()];
+ int i = 0;
+
+ for (ArgumentInfo arg : arguments) {
+ if (arg instanceof StateArgumentInfo) {
+ args[i++] = stateMap.get(arg.name);
+ } else if (arg instanceof TableArgumentInfo) {
+ TableArgumentInfo tableArg = (TableArgumentInfo) arg;
+ if (tableArg.name.equals(activeTableArg.name)) {
+ args[i++] = evalArgument;
+ } else {
+ args[i++] = null;
+ }
+ } else if (arg instanceof ScalarArgumentInfo) {
+ args[i++] = ((ScalarArgumentInfo) arg).value;
}
}
try {
evalMethod.invoke(function, args);
+ stateManager.updateStateForKey(partitionKey, stateMap);
} catch (InvocationTargetException e) {
String partitionInfo =
activeTableArg.partitionColumnNames != null
@@ -323,6 +352,16 @@ public class ProcessTableFunctionTestHarness<OUT>
implements AutoCloseable {
}
}
+ private Row extractPartitionKey(TableArgumentInfo tableArg, Row row) {
+ if (tableArg.partitionColumnNames == null ||
tableArg.partitionColumnNames.length == 0) {
+ return Row.of();
+ }
+
+ Object[] keyValues =
+
Arrays.stream(tableArg.partitionColumnNames).map(row::getField).toArray();
+ return Row.of(keyValues);
+ }
+
/** Collector implementation that stores output in the harness. */
private class HarnessCollector implements Collector<OUT> {
private ArgumentInfo activeTableArg;
@@ -337,10 +376,11 @@ public class ProcessTableFunctionTestHarness<OUT>
implements AutoCloseable {
public void collect(OUT record) {
OUT finalRecord;
- if (activeTableArg == null || !activeTableArg.isTableArgument) {
+ if (activeTableArg == null || !(activeTableArg instanceof
TableArgumentInfo)) {
finalRecord = record;
} else {
- switch (activeTableArg.prependStrategy) {
+ TableArgumentInfo tableArg = (TableArgumentInfo)
activeTableArg;
+ switch (tableArg.prependStrategy) {
case ALL_COLUMNS:
finalRecord = prependAllColumns(record);
break;
@@ -352,7 +392,7 @@ public class ProcessTableFunctionTestHarness<OUT>
implements AutoCloseable {
break;
default:
throw new IllegalStateException(
- "Unknown prepend strategy: " +
activeTableArg.prependStrategy);
+ "Unknown prepend strategy: " +
tableArg.prependStrategy);
}
}
@@ -384,8 +424,11 @@ public class ProcessTableFunctionTestHarness<OUT>
implements AutoCloseable {
int totalPartitionKeyCount = 0;
for (ArgumentInfo arg : arguments) {
- if (arg.isSetSemantic && arg.partitionColumnNames != null) {
- totalPartitionKeyCount += arg.partitionColumnNames.length;
+ if (arg instanceof TableArgumentInfo) {
+ TableArgumentInfo tableArg = (TableArgumentInfo) arg;
+ if (tableArg.isSetSemantic &&
tableArg.partitionColumnNames != null) {
+ totalPartitionKeyCount +=
tableArg.partitionColumnNames.length;
+ }
}
}
@@ -394,19 +437,22 @@ public class ProcessTableFunctionTestHarness<OUT>
implements AutoCloseable {
Row result = new Row(ptfRow.getKind(), totalArity);
- // Extract partition key values from active row
- Object[] partitionKeyValues = new
Object[activeTableArg.partitionColumnNames.length];
- for (int i = 0; i < activeTableArg.partitionColumnNames.length;
i++) {
- String columnName = activeTableArg.partitionColumnNames[i];
- int columnIndex = getFieldIndex(activeTableArg.dataType,
columnName);
+ TableArgumentInfo activeTableInfo = (TableArgumentInfo)
activeTableArg;
+ Object[] partitionKeyValues = new
Object[activeTableInfo.partitionColumnNames.length];
+ for (int i = 0; i < activeTableInfo.partitionColumnNames.length;
i++) {
+ String columnName = activeTableInfo.partitionColumnNames[i];
+ int columnIndex = getFieldIndex(activeTableInfo.dataType,
columnName);
partitionKeyValues[i] = activeRow.getField(columnIndex);
}
int resultIndex = 0;
for (ArgumentInfo arg : arguments) {
- if (arg.isSetSemantic && arg.partitionColumnNames != null) {
- for (int i = 0; i < arg.partitionColumnNames.length; i++) {
- result.setField(resultIndex++, partitionKeyValues[i]);
+ if (arg instanceof TableArgumentInfo) {
+ TableArgumentInfo tableArg = (TableArgumentInfo) arg;
+ if (tableArg.isSetSemantic &&
tableArg.partitionColumnNames != null) {
+ for (int i = 0; i <
tableArg.partitionColumnNames.length; i++) {
+ result.setField(resultIndex++,
partitionKeyValues[i]);
+ }
}
}
}
@@ -496,6 +542,7 @@ public class ProcessTableFunctionTestHarness<OUT>
implements AutoCloseable {
private final Map<String, ScalarArgumentConfiguration> scalarArgs =
new HashMap<>();
private final Map<String, TableArgumentConfiguration> tableArgs = new
HashMap<>();
private final Map<String, PartitionConfiguration> partitionConfigs =
new HashMap<>();
+ private final Map<String, StateArgumentConfiguration> stateArgs = new
HashMap<>();
private Builder(Class<? extends ProcessTableFunction<OUT>>
functionClass) {
this.functionClass = checkNotNull(functionClass, "functionClass
must not be null");
@@ -568,6 +615,20 @@ public class ProcessTableFunctionTestHarness<OUT>
implements AutoCloseable {
return this;
}
+ /** Sets initial state for a state parameter. */
+ public Builder<OUT> withInitialStateForKey(
+ String stateName, Row partitionKey, Object state) {
+ checkNotNull(stateName, "stateName must not be null");
+ checkNotNull(partitionKey, "partitionKey must not be null");
+ checkNotNull(state, "state must not be null");
+
+ stateArgs
+ .computeIfAbsent(stateName, k -> new
StateArgumentConfiguration())
+ .initialValues
+ .put(partitionKey, state);
+ return this;
+ }
+
//
---------------------------------------------------------------------
// Partitioning
//
---------------------------------------------------------------------
@@ -610,6 +671,8 @@ public class ProcessTableFunctionTestHarness<OUT>
implements AutoCloseable {
public ProcessTableFunctionTestHarness<OUT> build() throws Exception {
ProcessTableFunction<OUT> function = instantiateFunction();
+ ClassLoader classLoader =
Thread.currentThread().getContextClassLoader();
+
DataTypeFactory dataTypeFactory = createDataTypeFactory();
TypeInference baseTypeInference =
function.getTypeInference(dataTypeFactory);
TypeInference systemTypeInference =
@@ -618,21 +681,42 @@ public class ProcessTableFunctionTestHarness<OUT>
implements AutoCloseable {
List<ArgumentInfo> arguments =
extractAndValidateTypeInference(function,
systemTypeInference);
- FunctionContext functionContext =
- new FunctionContext(null,
Thread.currentThread().getContextClassLoader(), null);
+ FunctionContext functionContext = new FunctionContext(null,
classLoader, null);
Method evalMethod = findEvalMethod();
validateEvalMethodSupported(evalMethod, arguments);
validatePartitionConsistency(arguments);
+ validateInitialStateKeys(arguments);
+
+ Map<String, TableArgumentConverters> argumentConverters = new
HashMap<>();
+ Map<String, StateConverter> stateConverters = new HashMap<>();
+ createConverters(arguments, argumentConverters, stateConverters,
classLoader);
+
+ // Create state manager
+ List<StateArgumentInfo> stateArguments =
ArgumentInfo.filterStateArguments(arguments);
+ TestHarnessStateManager stateManager =
+ new TestHarnessStateManager(
+ stateArguments, stateConverters,
extractPartitionKeyInfo(arguments));
+
+ // Populate initial state
+ for (Map.Entry<String, StateArgumentConfiguration> entry :
stateArgs.entrySet()) {
+ String stateName = entry.getKey();
+ for (Map.Entry<Row, Object> stateEntry :
+ entry.getValue().initialValues.entrySet()) {
+ stateManager.setStateForKey(
+ stateName, stateEntry.getKey(),
stateEntry.getValue());
+ }
+ }
- Map<String, ConverterPair> argumentConverters = new HashMap<>();
- createConverters(arguments, argumentConverters);
+ // Extract table arguments for output type derivation
+ // SystemTypeInference needs table semantics for pass-through
column deduplication
+ List<TableArgumentInfo> tableArgs =
ArgumentInfo.filterTableArguments(arguments);
- // Derive output schema using SystemTypeInference (includes
deduplication)
+ // Derive output schema using SystemTypeInference
DataType derivedOutputType =
deriveOutputTypeFromSystemInference(
- function, dataTypeFactory, systemTypeInference,
arguments);
+ function, dataTypeFactory, systemTypeInference,
tableArgs);
// Create output converter for PTF emissions
DataStructureConverter<Object, Object> harnessOutputConverter =
@@ -643,23 +727,9 @@ public class ProcessTableFunctionTestHarness<OUT>
implements AutoCloseable {
functionContext,
evalMethod,
arguments,
- extractScalarValues(arguments),
argumentConverters,
- harnessOutputConverter);
- }
-
- /** Extracts scalar values from configs, creating a map keyed by
argument name. */
- private Map<String, Object> extractScalarValues(List<ArgumentInfo>
arguments) {
- Map<String, Object> values = new HashMap<>();
- for (ArgumentInfo arg : arguments) {
- if (arg.isScalar) {
- ScalarArgumentConfiguration config =
scalarArgs.get(arg.name);
- if (config != null) {
- values.put(arg.name, config.value);
- }
- }
- }
- return values;
+ harnessOutputConverter,
+ stateManager);
}
/**
@@ -675,58 +745,60 @@ public class ProcessTableFunctionTestHarness<OUT>
implements AutoCloseable {
}
/**
- * Creates and initializes data structure converters for all table
arguments.
+ * Creates and initializes converters for all table and state
arguments.
*
- * <p>For Row types, both input and output converters are the same
(between Row and
- * RowData).
- *
- * <p>For structured types, input converter uses Row types (Row to
RowData), and the output
- * converter uses the structured type.
+ * <p>For table arguments with Row types, both converters are the same
(between Row and
+ * RowData). For structured types, toNamedRow uses Row type (Row to
RowData), and
+ * toEvalArgument uses the structured type.
*/
private void createConverters(
- List<ArgumentInfo> arguments, Map<String, ConverterPair>
argumentConverters) {
- ClassLoader classLoader =
Thread.currentThread().getContextClassLoader();
+ List<ArgumentInfo> arguments,
+ Map<String, TableArgumentConverters> argumentConverters,
+ Map<String, StateConverter> stateConverters,
+ ClassLoader classLoader)
+ throws Exception {
+
+ for (StateArgumentInfo stateArg :
ArgumentInfo.filterStateArguments(arguments)) {
+ StateConverter converter =
createStateConverter(stateArg.dataType, classLoader);
+ stateConverters.put(stateArg.name, converter);
+ }
- for (ArgumentInfo arg : arguments) {
- if (arg.isTableArgument) {
- String converterKey = arg.name;
-
- LogicalType logicalType = arg.dataType.getLogicalType();
- boolean isStructuredType =
- logicalType instanceof StructuredType
- && ((StructuredType) logicalType)
- .getImplementationClass()
- .isPresent();
-
- if (isStructuredType) {
- StructuredType structuredType = (StructuredType)
logicalType;
- List<RowType.RowField> rowFields = new ArrayList<>();
- for (StructuredType.StructuredAttribute attr :
- structuredType.getAttributes()) {
- rowFields.add(new RowType.RowField(attr.getName(),
attr.getType()));
- }
- RowType rowType = new
RowType(logicalType.isNullable(), rowFields);
- DataType rowDataType =
TypeConversions.fromLogicalToDataType(rowType);
+ for (TableArgumentInfo tableArg :
ArgumentInfo.filterTableArguments(arguments)) {
+ String converterKey = tableArg.name;
+
+ LogicalType logicalType = tableArg.dataType.getLogicalType();
+ boolean isStructuredType =
+ logicalType instanceof StructuredType
+ && ((StructuredType) logicalType)
+ .getImplementationClass()
+ .isPresent();
+
+ if (isStructuredType) {
+ StructuredType structuredType = (StructuredType)
logicalType;
+ List<RowType.RowField> rowFields = new ArrayList<>();
+ for (StructuredType.StructuredAttribute attr :
structuredType.getAttributes()) {
+ rowFields.add(new RowType.RowField(attr.getName(),
attr.getType()));
+ }
+ RowType rowType = new RowType(logicalType.isNullable(),
rowFields);
+ DataType rowDataType =
TypeConversions.fromLogicalToDataType(rowType);
- DataStructureConverter<Object, Object> inputConverter =
-
DataStructureConverters.getConverter(rowDataType);
- inputConverter.open(classLoader);
+ DataStructureConverter<Object, Object> toNamedRow =
+ DataStructureConverters.getConverter(rowDataType);
+ toNamedRow.open(classLoader);
- DataStructureConverter<Object, Object> outputConverter
=
-
DataStructureConverters.getConverter(arg.dataType);
- outputConverter.open(classLoader);
+ DataStructureConverter<Object, Object> toEvalArgument =
+
DataStructureConverters.getConverter(tableArg.dataType);
+ toEvalArgument.open(classLoader);
- argumentConverters.put(
- converterKey, new
ConverterPair(inputConverter, outputConverter));
- } else {
- // For Row types, input and output converters are the
same
- DataStructureConverter<Object, Object> converter =
-
DataStructureConverters.getConverter(arg.dataType);
- converter.open(classLoader);
+ argumentConverters.put(
+ converterKey, new
TableArgumentConverters(toNamedRow, toEvalArgument));
+ } else {
+ DataStructureConverter<Object, Object> converter =
+
DataStructureConverters.getConverter(tableArg.dataType);
+ converter.open(classLoader);
- argumentConverters.put(
- converterKey, new ConverterPair(converter,
converter));
- }
+ argumentConverters.put(
+ converterKey, new
TableArgumentConverters(converter, converter));
}
}
}
@@ -757,8 +829,8 @@ public class ProcessTableFunctionTestHarness<OUT>
implements AutoCloseable {
}
/**
- * Validates that the eval() method doesn't use unsupported features.
Temporary, until state
- * and context is supported.
+ * Validates that the eval() method doesn't use unsupported features.
Temporary, until
+ * context is supported.
*/
private void validateEvalMethodSupported(Method evalMethod,
List<ArgumentInfo> arguments) {
Parameter[] parameters = evalMethod.getParameters();
@@ -774,34 +846,32 @@ public class ProcessTableFunctionTestHarness<OUT>
implements AutoCloseable {
+ "Found Context parameter at
position %d in eval() method. ",
i));
}
-
- if (param.isAnnotationPresent(StateHint.class)) {
- throw new IllegalStateException(
- String.format(
- "ProcessTableFunctionTestHarness does not
yet support state parameters. "
- + "Found @StateHint parameter at
position %d in eval() method. ",
- i));
- }
}
if (parameters.length != arguments.size()) {
+ long stateCount =
ArgumentInfo.filterStateArguments(arguments).size();
+ long nonStateCount = arguments.size() - stateCount;
throw new IllegalStateException(
String.format(
- "Parameter count mismatch: eval() has %d
parameters but only %d arguments were extracted. "
- + "This may indicate missing
@ArgumentHint annotations.",
- parameters.length, arguments.size()));
+ "Parameter count mismatch: eval() has %d
parameters but expected %d "
+ + "(%d state + %d table/scalar
arguments). "
+ + "eval() signature: %s. "
+ + "This may indicate missing
@ArgumentHint or @StateHint annotations.",
+ parameters.length,
+ arguments.size(),
+ stateCount,
+ nonStateCount,
+ evalMethod));
}
- for (int i = 0; i < parameters.length; i++) {
+ for (int i = 0; i < arguments.size(); i++) {
Parameter param = parameters[i];
Class<?> paramType = param.getType();
ArgumentInfo arg = arguments.get(i);
- if (arg.isScalar) {
- ScalarArgumentConfiguration config =
scalarArgs.get(arg.name);
- if (config != null
- && config.value != null
- &&
!paramType.isAssignableFrom(config.value.getClass())) {
+ if (arg instanceof ScalarArgumentInfo) {
+ Object value = ((ScalarArgumentInfo) arg).value;
+ if (value != null &&
!paramType.isAssignableFrom(value.getClass())) {
throw new IllegalStateException(
String.format(
"Type mismatch for scalar argument
'%s' at position %d: "
@@ -809,7 +879,7 @@ public class ProcessTableFunctionTestHarness<OUT>
implements AutoCloseable {
arg.name,
i,
paramType.getName(),
- config.value.getClass().getName()));
+ value.getClass().getName()));
}
}
}
@@ -821,10 +891,13 @@ public class ProcessTableFunctionTestHarness<OUT>
implements AutoCloseable {
* matching data types.
*/
private void validatePartitionConsistency(List<ArgumentInfo>
arguments) {
- final List<ArgumentInfo> partitionedTables = new ArrayList<>();
+ final List<TableArgumentInfo> partitionedTables = new
ArrayList<>();
for (ArgumentInfo arg : arguments) {
- if (arg.isSetSemantic && arg.partitionColumnNames != null) {
- partitionedTables.add(arg);
+ if (arg instanceof TableArgumentInfo) {
+ TableArgumentInfo tableArg = (TableArgumentInfo) arg;
+ if (tableArg.isSetSemantic &&
tableArg.partitionColumnNames != null) {
+ partitionedTables.add(tableArg);
+ }
}
}
@@ -832,11 +905,11 @@ public class ProcessTableFunctionTestHarness<OUT>
implements AutoCloseable {
return;
}
- final ArgumentInfo first = partitionedTables.get(0);
+ final TableArgumentInfo first = partitionedTables.get(0);
final int expectedPartitionColumnCount =
first.partitionColumnNames.length;
for (int i = 1; i < partitionedTables.size(); i++) {
- ArgumentInfo current = partitionedTables.get(i);
+ TableArgumentInfo current = partitionedTables.get(i);
if (current.partitionColumnNames.length !=
expectedPartitionColumnCount) {
throw new IllegalArgumentException(
@@ -876,17 +949,77 @@ public class ProcessTableFunctionTestHarness<OUT>
implements AutoCloseable {
}
}
- private DataType extractPartitionColumnType(ArgumentInfo arg, String
columnName) {
- if (!(arg.dataType instanceof FieldsDataType)) {
+ private void validateInitialStateKeys(List<ArgumentInfo> arguments) {
+ if (stateArgs.isEmpty()) {
+ return;
+ }
+
+ // All partitioned tables share the same partition key shape
(ensured by
+ // validatePartitionConsistency()), so any one suffices for
validation.
+ Optional<TableArgumentInfo> partitionedTable =
+ arguments.stream()
+ .filter(arg -> arg instanceof TableArgumentInfo)
+ .map(arg -> (TableArgumentInfo) arg)
+ .filter(t -> t.isSetSemantic &&
t.partitionColumnNames != null)
+ .findFirst();
+
+ if (partitionedTable.isEmpty()) {
+ // In cases of PTFs with OPTIONAL_PARTITION_BY and harness
setups with no partition
+ // setup, all data shares the same Row.of() key, so there is
no schema to validate
+ return;
+ }
+
+ TableArgumentInfo table = partitionedTable.get();
+ int expectedArity = table.partitionColumnNames.length;
+ LogicalType[] expectedTypes =
+ Arrays.stream(table.partitionColumnNames)
+ .map(col -> extractPartitionColumnType(table,
col).getLogicalType())
+ .toArray(LogicalType[]::new);
+
+ for (Map.Entry<String, StateArgumentConfiguration> entry :
stateArgs.entrySet()) {
+ for (Row key : entry.getValue().initialValues.keySet()) {
+ if (key.getArity() != expectedArity) {
+ throw new IllegalArgumentException(
+ String.format(
+ "Initial state key '%s' for state '%s'
has arity %d, "
+ + "but partition key has arity
%d.",
+ key, entry.getKey(), key.getArity(),
expectedArity));
+ }
+
+ for (int i = 0; i < expectedArity; i++) {
+ Object value = key.getField(i);
+ Class<?> expectedClass =
expectedTypes[i].getDefaultConversion();
+ if (value != null && !expectedClass.isInstance(value))
{
+ throw new IllegalArgumentException(
+ String.format(
+ "Initial state key '%s' for state
'%s' has type %s "
+ + "at position %d, but
partition column '%s' "
+ + "expects %s.",
+ key,
+ entry.getKey(),
+ value.getClass().getSimpleName(),
+ i,
+ table.partitionColumnNames[i],
+ expectedClass.getSimpleName()));
+ }
+ }
+ }
+ }
+ }
+
+ private DataType extractPartitionColumnType(TableArgumentInfo
tableArg, String columnName) {
+ if (!(tableArg.dataType instanceof FieldsDataType)) {
throw new IllegalStateException(
String.format(
"Cannot extract data type for partition column
'%s' of argument '%s': "
+ "argument data type is not a
FieldsDataType (actual: %s)",
- columnName, arg.name,
arg.dataType.getClass().getSimpleName()));
+ columnName,
+ tableArg.name,
+ tableArg.dataType.getClass().getSimpleName()));
}
- FieldsDataType fieldsDataType = (FieldsDataType) arg.dataType;
- List<String> fieldNames = getFieldNames(arg.dataType);
+ FieldsDataType fieldsDataType = (FieldsDataType) tableArg.dataType;
+ List<String> fieldNames = getFieldNames(tableArg.dataType);
List<DataType> fieldDataTypes = fieldsDataType.getChildren();
int fieldIndex = fieldNames.indexOf(columnName);
@@ -897,7 +1030,32 @@ public class ProcessTableFunctionTestHarness<OUT>
implements AutoCloseable {
throw new IllegalStateException(
String.format(
"Partition column '%s' not found in argument '%s'",
- columnName, arg.name));
+ columnName, tableArg.name));
+ }
+
+ private TestHarnessStateManager.PartitionKeyInfo
extractPartitionKeyInfo(
+ List<ArgumentInfo> arguments) {
+ // All partitioned tables share the same partition key shape
(ensured by
+ // validatePartitionConsistency()), so any one suffices.
+ Optional<TableArgumentInfo> partitionedTable =
+ arguments.stream()
+ .filter(arg -> arg instanceof TableArgumentInfo)
+ .map(arg -> (TableArgumentInfo) arg)
+ .filter(t -> t.isSetSemantic &&
t.partitionColumnNames != null)
+ .findFirst();
+
+ if (partitionedTable.isEmpty()) {
+ return new TestHarnessStateManager.PartitionKeyInfo(0, null,
null);
+ }
+
+ TableArgumentInfo table = partitionedTable.get();
+ String[] columnNames = table.partitionColumnNames;
+ LogicalType[] columnTypes =
+ Arrays.stream(columnNames)
+ .map(col -> extractPartitionColumnType(table,
col).getLogicalType())
+ .toArray(LogicalType[]::new);
+ return new TestHarnessStateManager.PartitionKeyInfo(
+ columnNames.length, columnNames, columnTypes);
}
//
---------------------------------------------------------------------
@@ -911,7 +1069,8 @@ public class ProcessTableFunctionTestHarness<OUT>
implements AutoCloseable {
* table argument rules, static argument trait validation, etc.
*/
private List<ArgumentInfo> extractAndValidateTypeInference(
- ProcessTableFunction<OUT> function, TypeInference
systemTypeInference) {
+ ProcessTableFunction<OUT> function, TypeInference
systemTypeInference)
+ throws Exception {
Optional<List<StaticArgument>> staticArgsOpt =
systemTypeInference.getStaticArguments();
if (staticArgsOpt.isEmpty()) {
@@ -928,7 +1087,7 @@ public class ProcessTableFunctionTestHarness<OUT>
implements AutoCloseable {
}
}
- List<ArgumentInfo> arguments = new ArrayList<>();
+ List<ArgumentInfo> tableAndScalarArguments = new ArrayList<>();
for (StaticArgument staticArg : userArgs) {
boolean isScalar =
staticArg.getTraits().contains(StaticArgumentTrait.SCALAR);
@@ -940,7 +1099,7 @@ public class ProcessTableFunctionTestHarness<OUT>
implements AutoCloseable {
if (isScalar || isTableArg) {
ArgumentInfo argInfo = buildArgumentInfo(staticArg);
- arguments.add(argInfo);
+ tableAndScalarArguments.add(argInfo);
} else {
throw new IllegalStateException(
"Unknown argument type for StaticArgument. "
@@ -948,9 +1107,98 @@ public class ProcessTableFunctionTestHarness<OUT>
implements AutoCloseable {
}
}
- validateArgumentConfiguration(arguments);
+ validateArgumentConfiguration(tableAndScalarArguments);
+
+ // Extract state arguments from TypeInference
+ List<StateArgumentInfo> stateArguments = new ArrayList<>();
+
+ Map<String, StateTypeStrategy> stateStrategies =
+ systemTypeInference.getStateTypeStrategies();
+
+ DataTypeFactory dataTypeFactory = createDataTypeFactory();
+
+ List<TableArgumentInfo> tableArgs =
+ ArgumentInfo.filterTableArguments(tableAndScalarArguments);
+ List<DataType> argumentDataTypes = new ArrayList<>();
+ for (TableArgumentInfo tArg : tableArgs) {
+ argumentDataTypes.add(tArg.dataType);
+ }
+ Map<Integer, TableSemantics> tableSemanticsMap = new HashMap<>();
+ for (int i = 0; i < tableArgs.size(); i++) {
+ TableArgumentInfo tArg = tableArgs.get(i);
+ int[] partitionIndices = getPartitionColumnIndices(tArg);
+ tableSemanticsMap.put(
+ i, new TestHarnessTableSemantics(tArg.dataType,
partitionIndices));
+ }
+
+ TestHarnessCallContext callContext = new TestHarnessCallContext();
+ callContext.typeFactory = dataTypeFactory;
+ callContext.argumentDataTypes = argumentDataTypes;
+ callContext.tableSemantics = tableSemanticsMap;
+ callContext.functionDefinition = function;
+ callContext.name = function.getClass().getSimpleName();
+
+ for (Map.Entry<String, StateTypeStrategy> entry :
stateStrategies.entrySet()) {
+ String stateName = entry.getKey();
+ StateTypeStrategy strategy = entry.getValue();
+
+ Optional<DataType> dataTypeOpt =
strategy.inferType(callContext);
+ if (dataTypeOpt.isEmpty()) {
+ throw new IllegalStateException(
+ String.format(
+ "Could not infer data type for state
parameter '%s'",
+ stateName));
+ }
+ DataType stateDataType = dataTypeOpt.get();
+
+ Optional<Duration> ttlOpt =
strategy.getTimeToLive(callContext);
+ stateArguments.add(
+ new StateArgumentInfo(stateName, stateDataType,
ttlOpt.orElse(null)));
+ }
+
+ List<ArgumentInfo> allArguments = new ArrayList<>();
+ allArguments.addAll(stateArguments);
+ allArguments.addAll(tableAndScalarArguments);
- return arguments;
+ return allArguments;
+ }
+
+ /** Creates appropriate StateConverter for the given state data type.
*/
+ private StateConverter createStateConverter(DataType stateDataType,
ClassLoader classLoader)
+ throws Exception {
+ LogicalType logicalType = stateDataType.getLogicalType();
+
+ if (logicalType instanceof ArrayType) {
+ ArrayType arrayType = (ArrayType) logicalType;
+ DataType elementType = stateDataType.getChildren().get(0);
+ DataStructureConverter<Object, Object> elementConverter =
+ DataStructureConverters.getConverter(elementType);
+ elementConverter.open(classLoader);
+ return new ListViewStateConverter(arrayType, elementConverter);
+ } else if (logicalType instanceof MapType) {
+ MapType mapType = (MapType) logicalType;
+ DataType keyType = stateDataType.getChildren().get(0);
+ DataType valueType = stateDataType.getChildren().get(1);
+ DataStructureConverter<Object, Object> keyConverter =
+ DataStructureConverters.getConverter(keyType);
+ DataStructureConverter<Object, Object> valueConverter =
+ DataStructureConverters.getConverter(valueType);
+ keyConverter.open(classLoader);
+ valueConverter.open(classLoader);
+ return new MapViewStateConverter(mapType, keyConverter,
valueConverter);
+ } else if (logicalType instanceof RowType) {
+ RowType rowType = (RowType) logicalType;
+ DataStructureConverter<Object, Object> converter =
+ DataStructureConverters.getConverter(stateDataType);
+ converter.open(classLoader);
+ return new RowStateConverter(converter, rowType);
+ } else {
+ DataStructureConverter<Object, Object> converter =
+ DataStructureConverters.getConverter(stateDataType);
+ converter.open(classLoader);
+ Class<?> stateClass = stateDataType.getConversionClass();
+ return new StructuredTypeStateConverter(converter, stateClass);
+ }
}
/** Checks if an argument name is a system-reserved argument. */
@@ -1034,8 +1282,14 @@ public class ProcessTableFunctionTestHarness<OUT>
implements AutoCloseable {
boolean hasPassColumnsThrough =
staticArg.getTraits().contains(StaticArgumentTrait.PASS_COLUMNS_THROUGH);
- return new ArgumentInfo(
- name, dataType, primaryTrait, partitionColumnNames,
hasPassColumnsThrough);
+ if (primaryTrait == ArgumentTrait.SCALAR) {
+ ScalarArgumentConfiguration config = scalarArgs.get(name);
+ Object value = config != null ? config.value : null;
+ return new ScalarArgumentInfo(name, dataType, value);
+ } else {
+ return new TableArgumentInfo(
+ name, dataType, primaryTrait, partitionColumnNames,
hasPassColumnsThrough);
+ }
}
private ArgumentTrait extractPrimaryTrait(EnumSet<StaticArgumentTrait>
staticTraits) {
@@ -1083,7 +1337,7 @@ public class ProcessTableFunctionTestHarness<OUT>
implements AutoCloseable {
/** Validates scalar argument values are configured and no unknown
arguments exist. */
private void validateArgumentConfiguration(List<ArgumentInfo>
arguments) {
for (ArgumentInfo arg : arguments) {
- if (arg.isScalar && !scalarArgs.containsKey(arg.name)) {
+ if (arg instanceof ScalarArgumentInfo &&
!scalarArgs.containsKey(arg.name)) {
throw new IllegalStateException(
String.format(
"Missing required scalar argument '%s'. "
@@ -1137,22 +1391,20 @@ public class ProcessTableFunctionTestHarness<OUT>
implements AutoCloseable {
ProcessTableFunction<OUT> function,
DataTypeFactory dataTypeFactory,
TypeInference systemTypeInference,
- List<ArgumentInfo> arguments) {
+ List<TableArgumentInfo> arguments) {
List<DataType> argumentDataTypes = new ArrayList<>();
- for (ArgumentInfo arg : arguments) {
+ for (TableArgumentInfo arg : arguments) {
argumentDataTypes.add(arg.dataType);
}
Map<Integer, TableSemantics> tableSemanticsMap = new HashMap<>();
for (int i = 0; i < arguments.size(); i++) {
- ArgumentInfo arg = arguments.get(i);
- if (arg.isTableArgument) {
- int[] partitionIndices = getPartitionColumnIndices(arg);
- TableSemantics semantics =
- new TestHarnessTableSemantics(arg.dataType,
partitionIndices);
- tableSemanticsMap.put(i, semantics);
- }
+ TableArgumentInfo arg = arguments.get(i);
+ int[] partitionIndices = getPartitionColumnIndices(arg);
+ TableSemantics semantics =
+ new TestHarnessTableSemantics(arg.dataType,
partitionIndices);
+ tableSemanticsMap.put(i, semantics);
}
TestHarnessCallContext callContext = new TestHarnessCallContext();
@@ -1174,28 +1426,12 @@ public class ProcessTableFunctionTestHarness<OUT>
implements AutoCloseable {
return outputTypeOpt.get();
}
- private static List<String> extractFieldNames(DataType dataType) {
- LogicalType logicalType = dataType.getLogicalType();
- if (logicalType instanceof RowType) {
- return ((RowType) logicalType).getFieldNames();
- } else if (logicalType instanceof StructuredType) {
- return ((StructuredType) logicalType)
- .getAttributes().stream()
-
.map(StructuredType.StructuredAttribute::getName)
- .collect(java.util.stream.Collectors.toList());
- } else {
- throw new IllegalStateException(
- "Expected RowType or StructuredType, got: "
- + logicalType.getClass().getSimpleName());
- }
- }
-
- private int[] getPartitionColumnIndices(ArgumentInfo arg) {
+ private int[] getPartitionColumnIndices(TableArgumentInfo arg) {
if (arg.partitionColumnNames == null ||
arg.partitionColumnNames.length == 0) {
return new int[0];
}
- List<String> fieldNames = extractFieldNames(arg.dataType);
+ List<String> fieldNames = getFieldNames(arg.dataType);
int[] indices = new int[arg.partitionColumnNames.length];
for (int i = 0; i < arg.partitionColumnNames.length; i++) {
@@ -1252,31 +1488,66 @@ public class ProcessTableFunctionTestHarness<OUT>
implements AutoCloseable {
}
/**
- * Metadata for a single argument extracted from type inference.
+ * Base class for PTF eval() arguments.
*
* <p>Represents validated argument information combining PTF signature,
type inference results,
* and builder configuration.
*/
- private static class ArgumentInfo {
+ private abstract static class ArgumentInfo {
final String name;
final DataType dataType;
+
+ ArgumentInfo(String name, DataType dataType) {
+ this.name = name;
+ this.dataType = dataType;
+ }
+
+ static List<StateArgumentInfo> filterStateArguments(List<ArgumentInfo>
arguments) {
+ return arguments.stream()
+ .filter(arg -> arg instanceof StateArgumentInfo)
+ .map(arg -> (StateArgumentInfo) arg)
+ .collect(Collectors.toList());
+ }
+
+ static List<TableArgumentInfo> filterTableArguments(List<ArgumentInfo>
arguments) {
+ return arguments.stream()
+ .filter(arg -> arg instanceof TableArgumentInfo)
+ .map(arg -> (TableArgumentInfo) arg)
+ .collect(Collectors.toList());
+ }
+
+ static List<ScalarArgumentInfo>
filterScalarArguments(List<ArgumentInfo> arguments) {
+ return arguments.stream()
+ .filter(arg -> arg instanceof ScalarArgumentInfo)
+ .map(arg -> (ScalarArgumentInfo) arg)
+ .collect(Collectors.toList());
+ }
+ }
+
+    /**
+     * Metadata for a state parameter of eval(), together with its declared TTL.
+     *
+     * <p>Package-private so that {@link TestHarnessStateManager} can consume it. NOTE(review): the
+     * state manager in this change does not read {@code ttl}; the commit states TTL is not yet
+     * supported.
+     */
+    static class StateArgumentInfo extends ArgumentInfo {
+        final Duration ttl;
+
+        StateArgumentInfo(String stateName, DataType stateType, Duration stateTtl) {
+            super(stateName, stateType);
+            this.ttl = stateTtl;
+        }
+    }
+
+ /** Table argument with partitioning and output prepending strategy. */
+ private static class TableArgumentInfo extends ArgumentInfo {
final String[] partitionColumnNames;
- final boolean isScalar;
- final boolean isTableArgument;
final boolean isSetSemantic;
final OutputPrependStrategy prependStrategy;
- ArgumentInfo(
+ TableArgumentInfo(
String name,
DataType dataType,
ArgumentTrait primaryTrait,
String[] partitionColumnNames,
boolean hasPassColumnsThrough) {
- this.name = name;
- this.dataType = dataType;
+ super(name, dataType);
this.partitionColumnNames = partitionColumnNames;
- this.isScalar = (primaryTrait == ArgumentTrait.SCALAR);
- this.isTableArgument = (primaryTrait != ArgumentTrait.SCALAR);
this.isSetSemantic = (primaryTrait ==
ArgumentTrait.SET_SEMANTIC_TABLE);
this.prependStrategy =
hasPassColumnsThrough
@@ -1287,6 +1558,16 @@ public class ProcessTableFunctionTestHarness<OUT>
implements AutoCloseable {
}
}
+    /**
+     * Scalar (constant) argument of eval().
+     *
+     * <p>{@code value} holds the constant configured for this argument, or {@code null} when no
+     * configuration was supplied.
+     */
+    private static class ScalarArgumentInfo extends ArgumentInfo {
+        final Object value;
+
+        ScalarArgumentInfo(String argName, DataType argType, Object constantValue) {
+            super(argName, argType);
+            this.value = constantValue;
+        }
+    }
+
private static class TableArgumentConfiguration {
final AbstractDataType<?> explicitType;
@@ -1310,4 +1591,12 @@ public class ProcessTableFunctionTestHarness<OUT>
implements AutoCloseable {
this.columnNames = columnNames;
}
}
+
+    /**
+     * Builder-side configuration of one state argument.
+     *
+     * <p>{@code initialValues} maps a partition key to the initial external state value;
+     * presumably filled by {@code withInitialStateForKey} — confirm against the builder.
+     */
+    private static class StateArgumentConfiguration {
+        final Map<Row, Object> initialValues = new HashMap<>();
+
+        StateArgumentConfiguration() {}
+    }
}
diff --git
a/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/RowStateConverter.java
b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/RowStateConverter.java
new file mode 100644
index 00000000000..f293edb6aae
--- /dev/null
+++
b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/RowStateConverter.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.functions;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.data.conversion.DataStructureConverter;
+import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.types.Row;
+
+/**
+ * {@link StateConverter} for state declared as a {@link Row} of a given {@link RowType}.
+ *
+ * <p>All conversions delegate to the supplied {@link DataStructureConverter}; {@code null} is
+ * passed through unchanged in both directions.
+ */
+@Internal
+class RowStateConverter implements StateConverter {
+
+    private final DataStructureConverter<Object, Object> rowConverter;
+    private final RowType stateRowType;
+
+    RowStateConverter(DataStructureConverter<Object, Object> converter, RowType rowType) {
+        this.rowConverter = converter;
+        this.stateRowType = rowType;
+    }
+
+    @Override
+    public Object toInternal(Object external) {
+        return external == null ? null : rowConverter.toInternal(external);
+    }
+
+    @Override
+    public Object toExternal(Object internal) {
+        return internal == null ? null : rowConverter.toExternal(internal);
+    }
+
+    /** Builds a named row with every field of the row type set to null and converts it. */
+    @Override
+    public Object createNewInternalState() {
+        Row emptyRow = Row.withNames();
+        for (String fieldName : stateRowType.getFieldNames()) {
+            emptyRow.setField(fieldName, null);
+        }
+        return rowConverter.toInternal(emptyRow);
+    }
+}
diff --git
a/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/StateConverter.java
b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/StateConverter.java
new file mode 100644
index 00000000000..3d580aa77be
--- /dev/null
+++
b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/StateConverter.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.functions;
+
+import org.apache.flink.annotation.Internal;
+
+/**
+ * Translates PTF state between the object handed to eval() (value state such as a Row or POJO,
+ * ListView, MapView) and the internal form kept by the test harness (RowData, ArrayData, MapData).
+ */
+@Internal
+interface StateConverter {
+
+    /**
+     * Returns the internal form of a freshly initialized state instance. The state manager uses it
+     * when a partition key has no stored state yet and when a single state is cleared.
+     */
+    Object createNewInternalState();
+
+    /**
+     * Converts an external state object, as seen by eval(), into its internal storage form.
+     *
+     * @throws Exception if the conversion fails
+     */
+    Object toInternal(Object external) throws Exception;
+
+    /** Converts an internal storage form back into the external object passed to eval(). */
+    Object toExternal(Object internal);
+}
diff --git
a/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/StructuredTypeStateConverter.java
b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/StructuredTypeStateConverter.java
new file mode 100644
index 00000000000..5599aef14a7
--- /dev/null
+++
b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/StructuredTypeStateConverter.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.functions;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.data.conversion.DataStructureConverter;
+
+/**
+ * {@link StateConverter} for value state backed by a structured type (POJO).
+ *
+ * <p>Converts between the external POJO instance and its internal RowData representation via the
+ * supplied {@link DataStructureConverter}; {@code null} is passed through unchanged.
+ */
+@Internal
+class StructuredTypeStateConverter implements StateConverter {
+
+    private final DataStructureConverter<Object, Object> converter;
+    private final Class<?> pojoClass;
+
+    StructuredTypeStateConverter(
+            DataStructureConverter<Object, Object> converter, Class<?> pojoClass) {
+        this.converter = converter;
+        this.pojoClass = pojoClass;
+    }
+
+    @Override
+    public Object toInternal(Object external) {
+        if (external == null) {
+            return null;
+        }
+        return converter.toInternal(external);
+    }
+
+    @Override
+    public Object toExternal(Object internal) {
+        if (internal == null) {
+            return null;
+        }
+        return converter.toExternal(internal);
+    }
+
+    /**
+     * Instantiates the POJO through its no-argument constructor and converts it to internal form.
+     * Field initializers of the POJO therefore define the fresh state's values.
+     *
+     * @throws RuntimeException wrapping the reflective failure if the POJO cannot be instantiated
+     */
+    @Override
+    public Object createNewInternalState() {
+        Object newPojo;
+        try {
+            newPojo = pojoClass.getDeclaredConstructor().newInstance();
+        } catch (ReflectiveOperationException e) {
+            throw new RuntimeException(
+                    "Failed to create new instance of POJO class: " + pojoClass.getName(), e);
+        }
+        // Convert outside the try block so conversion errors surface as themselves instead of
+        // being misreported as instantiation failures.
+        return converter.toInternal(newPojo);
+    }
+}
diff --git
a/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/TestHarnessDataTypeFactory.java
b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/TestHarnessDataTypeFactory.java
index af748db7ad2..ff01478fba1 100644
---
a/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/TestHarnessDataTypeFactory.java
+++
b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/TestHarnessDataTypeFactory.java
@@ -28,6 +28,7 @@ import org.apache.flink.table.types.UnresolvedDataType;
import org.apache.flink.table.types.extraction.DataTypeExtractor;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.utils.LogicalTypeParser;
+import org.apache.flink.table.types.utils.TypeConversions;
import org.apache.flink.table.types.utils.TypeInfoDataTypeConverter;
/**
@@ -53,8 +54,7 @@ class TestHarnessDataTypeFactory implements DataTypeFactory {
public DataType createDataType(String typeString) {
LogicalType logicalType =
LogicalTypeParser.parse(typeString,
Thread.currentThread().getContextClassLoader());
- return
org.apache.flink.table.types.utils.TypeConversions.fromLogicalToDataType(
- logicalType);
+ return TypeConversions.fromLogicalToDataType(logicalType);
}
@Override
diff --git
a/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/TestHarnessStateManager.java
b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/TestHarnessStateManager.java
new file mode 100644
index 00000000000..35f18caa611
--- /dev/null
+++
b/flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/TestHarnessStateManager.java
@@ -0,0 +1,233 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.functions;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.types.Row;
+
+import javax.annotation.Nullable;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * State manager for {@link ProcessTableFunctionTestHarness}.
+ *
+ * <p>Stores state per partition key in internal form. It converts state to and from the external
+ * objects (value state, ListView, MapView) that eval() operates on, using the {@link
+ * StateConverter} registered for each state name.
+ *
+ * <p>Every accessor that takes a state name rejects names that are not declared state arguments
+ * with an {@link IllegalArgumentException}. A mistyped name therefore fails loudly instead of
+ * reading as "no state".
+ */
+@Internal
+class TestHarnessStateManager {
+
+    /** Internal state per partition key; each inner map goes from state name to internal data. */
+    private final Map<Row, Map<String, Object>> stateByKey = new HashMap<>();
+
+    private final List<ProcessTableFunctionTestHarness.StateArgumentInfo> stateArguments;
+    private final Map<String, StateConverter> stateConverters;
+    private final PartitionKeyInfo partitionKeyInfo;
+
+    TestHarnessStateManager(
+            List<ProcessTableFunctionTestHarness.StateArgumentInfo> stateArguments,
+            Map<String, StateConverter> stateConverters,
+            PartitionKeyInfo partitionKeyInfo) {
+        this.stateArguments = stateArguments;
+        this.stateConverters = stateConverters;
+        this.partitionKeyInfo = partitionKeyInfo;
+    }
+
+    /**
+     * Expected shape of partition keys supplied by callers.
+     *
+     * <p>When no column types are given, only the arity is checked.
+     */
+    static class PartitionKeyInfo {
+        final int arity;
+        @Nullable final String[] columnNames;
+        @Nullable final Class<?>[] columnTypes;
+
+        PartitionKeyInfo(
+                int arity,
+                @Nullable String[] columnNames,
+                @Nullable LogicalType[] columnLogicalTypes) {
+            this.arity = arity;
+            this.columnNames = columnNames;
+            this.columnTypes =
+                    columnLogicalTypes != null
+                            ? Arrays.stream(columnLogicalTypes)
+                                    .map(LogicalType::getDefaultConversion)
+                                    .toArray(Class<?>[]::new)
+                            : null;
+        }
+
+        /**
+         * Checks the key's arity and, if column types are known, the class of each non-null field.
+         *
+         * @throws IllegalArgumentException if the key does not match
+         */
+        void validate(Row key) {
+            if (key.getArity() != arity) {
+                throw new IllegalArgumentException(
+                        String.format(
+                                "Partition key has arity %d, but expected arity %d.",
+                                key.getArity(), arity));
+            }
+            if (columnTypes == null) {
+                return;
+            }
+            for (int i = 0; i < arity; i++) {
+                Object value = key.getField(i);
+                if (value != null && !columnTypes[i].isInstance(value)) {
+                    String columnName = columnNames != null ? columnNames[i] : "position " + i;
+                    throw new IllegalArgumentException(
+                            String.format(
+                                    "Partition key has type %s at position %d, "
+                                            + "but partition column '%s' expects %s.",
+                                    value.getClass().getSimpleName(),
+                                    i,
+                                    columnName,
+                                    columnTypes[i].getSimpleName()));
+                }
+            }
+        }
+    }
+
+    /**
+     * Load state for a partition key. Creates new state instances if none exist. Converts internal
+     * storage to external objects (value state, ListView, MapView).
+     */
+    Map<String, Object> loadStateForKey(Row key) {
+        Map<String, Object> internalState =
+                stateByKey.computeIfAbsent(key, k -> createEmptyKeyState());
+
+        Map<String, Object> externalState = new HashMap<>();
+        for (ProcessTableFunctionTestHarness.StateArgumentInfo stateArg : stateArguments) {
+            Object internalData = internalState.get(stateArg.name);
+            Object external = convertToExternal(internalData, stateArg);
+            externalState.put(stateArg.name, external);
+        }
+        return externalState;
+    }
+
+    /**
+     * Update mutated state after eval() invocation. Converts external objects to internal format
+     * and replaces everything stored for the key.
+     */
+    void updateStateForKey(Row key, Map<String, Object> externalState) throws Exception {
+        Map<String, Object> internalState = new HashMap<>();
+        for (ProcessTableFunctionTestHarness.StateArgumentInfo stateArg : stateArguments) {
+            Object external = externalState.get(stateArg.name);
+            Object internalData = convertToInternal(external, stateArg);
+            internalState.put(stateArg.name, internalData);
+        }
+        stateByKey.put(key, internalState);
+    }
+
+    /** Clear all state for a partition key. */
+    void clearAllStatesForKey(Row key) {
+        partitionKeyInfo.validate(key);
+        stateByKey.remove(key);
+    }
+
+    /**
+     * Clear a specific state entry for a partition key, resetting it to a freshly created value.
+     * Does nothing if the key has no stored state. Unknown state names are always rejected.
+     */
+    void clearStateForKey(String stateName, Row key) {
+        partitionKeyInfo.validate(key);
+        ProcessTableFunctionTestHarness.StateArgumentInfo stateArg = findStateArgument(stateName);
+        Map<String, Object> internalState = stateByKey.get(key);
+        if (internalState != null) {
+            internalState.put(stateName, createNewStateInternalData(stateArg));
+        }
+    }
+
+    /** Sets the state for a given partition key, creating default values for its other states. */
+    void setStateForKey(String stateName, Row key, Object externalState) throws Exception {
+        partitionKeyInfo.validate(key);
+        ProcessTableFunctionTestHarness.StateArgumentInfo stateArg = findStateArgument(stateName);
+        Object internalData = convertToInternal(externalState, stateArg);
+
+        Map<String, Object> internalState =
+                stateByKey.computeIfAbsent(key, k -> createEmptyKeyState());
+        internalState.put(stateName, internalData);
+    }
+
+    /**
+     * Get the state for a given partition key.
+     *
+     * @return the external state, or {@code null} if the key has no stored state or the stored
+     *     value is null
+     * @throws IllegalArgumentException if the key is invalid or the state name is unknown
+     */
+    @SuppressWarnings("unchecked")
+    <T> T getStateForKey(String stateName, Row key) {
+        partitionKeyInfo.validate(key);
+        // Resolve the state name first so unknown names fail instead of returning null.
+        ProcessTableFunctionTestHarness.StateArgumentInfo stateArg = findStateArgument(stateName);
+        Map<String, Object> internalState = stateByKey.get(key);
+        if (internalState == null) {
+            return null;
+        }
+        Object internalData = internalState.get(stateName);
+        if (internalData == null) {
+            return null;
+        }
+        return (T) convertToExternal(internalData, stateArg);
+    }
+
+    /**
+     * Get all partition keys that have a specific state entry.
+     *
+     * @throws IllegalArgumentException if the state name is unknown
+     */
+    Set<Row> getKeysForState(String stateName) {
+        // Called for validation only; an unknown name must not yield a silently empty set.
+        findStateArgument(stateName);
+        return stateByKey.entrySet().stream()
+                .filter(entry -> entry.getValue().containsKey(stateName))
+                .map(Map.Entry::getKey)
+                .collect(Collectors.toSet());
+    }
+
+    /**
+     * Get all state values for a state name across all partition keys. Keys whose stored value is
+     * null are omitted.
+     */
+    @SuppressWarnings("unchecked")
+    <T> Map<Row, T> getStateForAllKeys(String stateName) {
+        ProcessTableFunctionTestHarness.StateArgumentInfo stateArg = findStateArgument(stateName);
+        Map<Row, T> result = new HashMap<>();
+        for (Map.Entry<Row, Map<String, Object>> entry : stateByKey.entrySet()) {
+            Object internalData = entry.getValue().get(stateName);
+            if (internalData != null) {
+                result.put(entry.getKey(), (T) convertToExternal(internalData, stateArg));
+            }
+        }
+        return result;
+    }
+
+    private Map<String, Object> createEmptyKeyState() {
+        Map<String, Object> newState = new HashMap<>();
+        for (ProcessTableFunctionTestHarness.StateArgumentInfo stateArg : stateArguments) {
+            newState.put(stateArg.name, createNewStateInternalData(stateArg));
+        }
+        return newState;
+    }
+
+    private Object createNewStateInternalData(
+            ProcessTableFunctionTestHarness.StateArgumentInfo stateArg) {
+        return stateConverters.get(stateArg.name).createNewInternalState();
+    }
+
+    private Object convertToExternal(
+            Object internalData, ProcessTableFunctionTestHarness.StateArgumentInfo stateArg) {
+        return stateConverters.get(stateArg.name).toExternal(internalData);
+    }
+
+    private Object convertToInternal(
+            Object external, ProcessTableFunctionTestHarness.StateArgumentInfo stateArg)
+            throws Exception {
+        return stateConverters.get(stateArg.name).toInternal(external);
+    }
+
+    /**
+     * Looks up a declared state argument by name.
+     *
+     * @throws IllegalArgumentException listing the available states if the name is unknown
+     */
+    private ProcessTableFunctionTestHarness.StateArgumentInfo findStateArgument(String stateName) {
+        for (ProcessTableFunctionTestHarness.StateArgumentInfo stateArg : stateArguments) {
+            if (stateArg.name.equals(stateName)) {
+                return stateArg;
+            }
+        }
+        String available =
+                stateArguments.stream().map(arg -> arg.name).collect(Collectors.joining(", "));
+        throw new IllegalArgumentException(
+                "Unknown state: '" + stateName + "'. Available states: [" + available + "]");
+    }
+}
diff --git
a/flink-table/flink-table-test-utils/src/test/java/org/apache/flink/table/runtime/functions/ProcessTableFunctionTestHarnessTest.java
b/flink-table/flink-table-test-utils/src/test/java/org/apache/flink/table/runtime/functions/ProcessTableFunctionTestHarnessTest.java
index a02e9a8eefd..88e1c8b32d8 100644
---
a/flink-table/flink-table-test-utils/src/test/java/org/apache/flink/table/runtime/functions/ProcessTableFunctionTestHarnessTest.java
+++
b/flink-table/flink-table-test-utils/src/test/java/org/apache/flink/table/runtime/functions/ProcessTableFunctionTestHarnessTest.java
@@ -24,13 +24,19 @@ import org.apache.flink.table.annotation.DataTypeHint;
import org.apache.flink.table.annotation.StateHint;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.ValidationException;
+import org.apache.flink.table.api.dataview.ListView;
+import org.apache.flink.table.api.dataview.MapView;
import org.apache.flink.table.functions.ProcessTableFunction;
import org.apache.flink.types.Row;
import org.apache.flink.types.RowKind;
import org.junit.jupiter.api.Test;
+import java.util.ArrayList;
import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertThrows;
@@ -136,6 +142,25 @@ class ProcessTableFunctionTestHarnessTest {
}
}
+    /**
+     * Stateful PTF with OPTIONAL_PARTITION_BY: emits a running row count kept in POJO state, either
+     * per partition or, without PARTITION BY, as a single count under the empty key.
+     */
+    @DataTypeHint("ROW<count BIGINT>")
+    public static class StatefulOptionalPartitionPTF extends ProcessTableFunction<Row> {
+        /** Counter state; starts at zero for a freshly created instance. */
+        public static class CounterState {
+            public long counter = 0L;
+        }
+
+        public void eval(
+                @StateHint CounterState state,
+                @ArgumentHint({
+                            ArgumentTrait.SET_SEMANTIC_TABLE,
+                            ArgumentTrait.OPTIONAL_PARTITION_BY
+                        })
+                        Row input) {
+            collect(Row.of(++state.counter));
+        }
+    }
+
/** Simple POJO for testing structured type input/output. */
public static class User {
public String name;
@@ -162,12 +187,12 @@ class ProcessTableFunctionTestHarnessTest {
return false;
}
User user = (User) o;
- return age == user.age && java.util.Objects.equals(name,
user.name);
+ return age == user.age && Objects.equals(name, user.name);
}
@Override
public int hashCode() {
- return java.util.Objects.hash(name, age);
+ return Objects.hash(name, age);
}
}
@@ -282,17 +307,94 @@ class ProcessTableFunctionTestHarnessTest {
}
}
- /** PTF with State parameter - should be rejected by test harness. */
- @DataTypeHint("ROW<value INT>")
- public static class PTFWithState extends ProcessTableFunction<Row> {
- public static class CountState {
+ /** PTF with simple value state - counts rows per partition. */
+ @DataTypeHint("ROW<count BIGINT>")
+ public static class PTFWithValueState extends ProcessTableFunction<Row> {
+ public static class CounterState {
public long counter = 0L;
}
public void eval(
- @StateHint CountState state,
- @ArgumentHint(ArgumentTrait.ROW_SEMANTIC_TABLE) Row input) {
- collect(input);
+ @StateHint CounterState state,
+ @ArgumentHint(ArgumentTrait.SET_SEMANTIC_TABLE) Row input) {
+ state.counter++;
+ collect(Row.of(state.counter));
+ }
+ }
+
+ /** PTF with ListView state - accumulates values in a list. */
+ @DataTypeHint("ROW<values ARRAY<INT>>")
+ public static class PTFWithListViewState extends ProcessTableFunction<Row>
{
+ public void eval(
+ @StateHint(type = @DataTypeHint("ARRAY<INT>"))
ListView<Integer> listState,
+ @ArgumentHint(ArgumentTrait.SET_SEMANTIC_TABLE) Row input)
+ throws Exception {
+ Integer value = input.getFieldAs("value");
+ listState.add(value);
+
+ // Collect all values as an array
+ List<Integer> values = new ArrayList<>();
+ for (Integer v : listState.get()) {
+ values.add(v);
+ }
+ collect(Row.of((Object) values.toArray(new Integer[0])));
+ }
+ }
+
+    /** PTF with MapView state: counts how often each input key has been seen. */
+    @DataTypeHint("ROW<key STRING, count INT>")
+    public static class PTFWithMapViewState extends ProcessTableFunction<Row> {
+        public void eval(
+                @StateHint MapView<String, Integer> mapState,
+                @ArgumentHint(ArgumentTrait.SET_SEMANTIC_TABLE) Row input)
+                throws Exception {
+            String key = input.getFieldAs("key");
+            Integer previous = mapState.get(key);
+            mapState.put(key, previous == null ? 1 : previous + 1);
+            collect(Row.of(key, mapState.get(key)));
+        }
+    }
+
+    /** PTF with Row state (as in the docs example): keeps a running count in a named Row. */
+    @DataTypeHint("ROW<count BIGINT>")
+    public static class PTFWithRowState extends ProcessTableFunction<Row> {
+        public void eval(
+                @StateHint(type = @DataTypeHint("ROW<count BIGINT>")) Row memory,
+                @ArgumentHint(ArgumentTrait.SET_SEMANTIC_TABLE) Row input) {
+            // A fresh Row state has a null count field, so the first row starts at 1.
+            Long previous = memory.getFieldAs("count");
+            Long newCount = previous == null ? 1L : previous + 1L;
+            memory.setField("count", newCount);
+            collect(Row.of(newCount));
+        }
+    }
+
+    /** PTF with a value state and a ListView state: emits a running count and a running sum. */
+    @DataTypeHint("ROW<count BIGINT, sum INT>")
+    public static class PTFWithMultipleStates extends ProcessTableFunction<Row> {
+        public static class CounterState {
+            public long count = 0L;
+        }
+
+        public void eval(
+                @StateHint CounterState counter,
+                @StateHint(type = @DataTypeHint("ARRAY<INT>")) ListView<Integer> history,
+                @ArgumentHint(ArgumentTrait.SET_SEMANTIC_TABLE) Row input)
+                throws Exception {
+            Integer incoming = input.getFieldAs("value");
+            counter.count++;
+            history.add(incoming);
+
+            int total = 0;
+            for (int element : history.get()) {
+                total += element;
+            }
+            collect(Row.of(counter.count, total));
         }
     }
@@ -597,6 +699,92 @@ class ProcessTableFunctionTestHarnessTest {
}
}
+    @Test
+    void testOptionalPartitionByWithStateNoPartition() throws Exception {
+        try (ProcessTableFunctionTestHarness<Row> harness =
+                ProcessTableFunctionTestHarness.ofClass(StatefulOptionalPartitionPTF.class)
+                        .withTableArgument("input", DataTypes.of("ROW<key STRING, value INT>"))
+                        .build()) {
+            harness.processElement(Row.of("A", 10));
+            harness.processElement(Row.of("B", 20));
+            harness.processElement(Row.of("A", 30));
+
+            // Without PARTITION BY all rows share one counter stored under the empty key.
+            assertThat(harness.getOutput()).containsExactly(Row.of(1L), Row.of(2L), Row.of(3L));
+
+            StatefulOptionalPartitionPTF.CounterState globalState =
+                    harness.getStateForKey("state", Row.of());
+            assertThat(globalState.counter).isEqualTo(3L);
+        }
+    }
+
+    @Test
+    void testOptionalPartitionByWithStateAndPartition() throws Exception {
+        try (ProcessTableFunctionTestHarness<Row> harness =
+                ProcessTableFunctionTestHarness.ofClass(StatefulOptionalPartitionPTF.class)
+                        .withTableArgument("input", DataTypes.of("ROW<key STRING, value INT>"))
+                        .withPartitionBy("input", "key")
+                        .build()) {
+            harness.processElement(Row.of("A", 10));
+            harness.processElement(Row.of("B", 20));
+            harness.processElement(Row.of("A", 30));
+
+            // With PARTITION BY the key column is prepended and each key has its own counter.
+            assertThat(harness.getOutput())
+                    .containsExactly(Row.of("A", 1L), Row.of("B", 1L), Row.of("A", 2L));
+
+            StatefulOptionalPartitionPTF.CounterState forA =
+                    harness.getStateForKey("state", Row.of("A"));
+            StatefulOptionalPartitionPTF.CounterState forB =
+                    harness.getStateForKey("state", Row.of("B"));
+            assertThat(forA.counter).isEqualTo(2L);
+            assertThat(forB.counter).isEqualTo(1L);
+        }
+    }
+
+    @Test
+    void testOptionalPartitionByWithInitialStateNoPartition() throws Exception {
+        StatefulOptionalPartitionPTF.CounterState seeded =
+                new StatefulOptionalPartitionPTF.CounterState();
+        seeded.counter = 10L;
+
+        try (ProcessTableFunctionTestHarness<Row> harness =
+                ProcessTableFunctionTestHarness.ofClass(StatefulOptionalPartitionPTF.class)
+                        .withTableArgument("input", DataTypes.of("ROW<key STRING, value INT>"))
+                        .withInitialStateForKey("state", Row.of(), seeded)
+                        .build()) {
+            // Processing continues from the seeded value.
+            harness.processElement(Row.of("A", 1));
+            StatefulOptionalPartitionPTF.CounterState current =
+                    harness.getStateForKey("state", Row.of());
+            assertThat(current.counter).isEqualTo(11L);
+
+            // setStateForKey overwrites the stored value.
+            StatefulOptionalPartitionPTF.CounterState replacement =
+                    new StatefulOptionalPartitionPTF.CounterState();
+            replacement.counter = 50L;
+            harness.setStateForKey("state", Row.of(), replacement);
+            current = harness.getStateForKey("state", Row.of());
+            assertThat(current.counter).isEqualTo(50L);
+
+            // clearStateForKey resets the single state to a fresh instance.
+            harness.clearStateForKey("state", Row.of());
+            current = harness.getStateForKey("state", Row.of());
+            assertThat(current.counter).isEqualTo(0L);
+
+            // clearAllStatesForKey removes the key entirely, so nothing is returned afterwards.
+            harness.processElement(Row.of("B", 2));
+            harness.clearAllStatesForKey(Row.of());
+            current = harness.getStateForKey("state", Row.of());
+            assertThat(current).isNull();
+        }
+    }
+
//
-------------------------------------------------------------------------
// Data Type Conversion Tests
//
-------------------------------------------------------------------------
@@ -902,7 +1090,7 @@ class ProcessTableFunctionTestHarnessTest {
// Verify that PASS_COLUMNS_THROUGH is rejected when used with
multiple table arguments
Exception exception =
assertThrows(
- org.apache.flink.table.api.ValidationException.class,
+ ValidationException.class,
() -> {
ProcessTableFunctionTestHarness.ofClass(
InvalidPassColumnsThroughMultiTablePTF.class)
@@ -998,22 +1186,6 @@ class ProcessTableFunctionTestHarnessTest {
.contains("position 0");
}
- @Test
- void testStateParameterRejected() {
- Exception exception =
- assertThrows(
- IllegalStateException.class,
- () ->
-
ProcessTableFunctionTestHarness.ofClass(PTFWithState.class)
- .withTableArgument("input",
DataTypes.of("ROW<value INT>"))
- .build());
-
- assertThat(exception.getMessage())
- .contains("does not yet support state parameters")
- .contains("@StateHint parameter")
- .contains("position 0");
- }
-
@Test
void testSetSemanticMissingPartitionConfigThrows() {
Exception exception =
@@ -1062,4 +1234,498 @@ class ProcessTableFunctionTestHarnessTest {
assertThat(exception.getMessage()).contains("Partition config already
exists");
}
+
+ // -------------------------------------------------------------------------
+ // State Tests
+ // -------------------------------------------------------------------------
+
+ /** Checks that a partition's value state is updated per element and readable via getStateForKey(). */
+ @Test
+ void testValueState() throws Exception {
+ try (ProcessTableFunctionTestHarness<Row> harness =
+ ProcessTableFunctionTestHarness.ofClass(PTFWithValueState.class)
+ .withTableArgument("input", DataTypes.of("ROW<name STRING, value INT>"))
+ .withPartitionBy("input", "name")
+ .build()) {
+
+ harness.processElementForTable("input", Row.of("Alice", 10));
+ assertThat(harness.getOutput()).containsExactly(Row.of("Alice", 1L));
+
+ PTFWithValueState.CounterState state = harness.getStateForKey("state", Row.of("Alice"));
+ assertThat(state.counter).isEqualTo(1L);
+
+ harness.processElementForTable("input", Row.of("Alice", 15));
+ assertThat(harness.getOutput().get(1)).isEqualTo(Row.of("Alice", 2L));
+
+ state = harness.getStateForKey("state", Row.of("Alice"));
+ assertThat(state.counter).isEqualTo(2L);
+ }
+ }
+
+ /** Checks that value state is kept separately for each partition key. */
+ @Test
+ void testValueStatePartitionIsolation() throws Exception {
+ try (ProcessTableFunctionTestHarness<Row> harness =
+ ProcessTableFunctionTestHarness.ofClass(PTFWithValueState.class)
+ .withTableArgument("input", DataTypes.of("ROW<name STRING, value INT>"))
+ .withPartitionBy("input", "name")
+ .build()) {
+
+ harness.processElementForTable("input", Row.of("Alice", 10));
+ harness.processElementForTable("input", Row.of("Bob", 20));
+ harness.processElementForTable("input", Row.of("Alice", 15));
+
+ PTFWithValueState.CounterState aliceState =
+ harness.getStateForKey("state", Row.of("Alice"));
+ PTFWithValueState.CounterState bobState = harness.getStateForKey("state", Row.of("Bob"));
+
+ assertThat(aliceState.counter).isEqualTo(2L);
+ assertThat(bobState.counter).isEqualTo(1L);
+ }
+ }
+
+ /** Checks that state seeded via withInitialStateForKey() applies only to its key (1), not key 2. */
+ @Test
+ void testValueStateWithInitialState() throws Exception {
+ PTFWithValueState.CounterState initialState = new PTFWithValueState.CounterState();
+ initialState.counter = 100L;
+
+ try (ProcessTableFunctionTestHarness<Row> harness =
+ ProcessTableFunctionTestHarness.ofClass(PTFWithValueState.class)
+ .withTableArgument("input", DataTypes.of("ROW<id INT>"))
+ .withPartitionBy("input", "id")
+ .withInitialStateForKey("state", Row.of(1), initialState)
+ .build()) {
+
+ PTFWithValueState.CounterState state = harness.getStateForKey("state", Row.of(1));
+ assertThat(state.counter).isEqualTo(100L);
+
+ harness.processElement(Row.of(1));
+ assertThat(harness.getOutput()).containsExactly(Row.of(1, 101L));
+
+ harness.processElement(Row.of(2));
+ assertThat(harness.getOutput().get(1)).isEqualTo(Row.of(2, 1L));
+ }
+ }
+
+ /** Checks that getKeysForState() reports every partition key that has been processed. */
+ @Test
+ void testGetStateKeys() throws Exception {
+ try (ProcessTableFunctionTestHarness<Row> harness =
+ ProcessTableFunctionTestHarness.ofClass(PTFWithValueState.class)
+ .withTableArgument("input", DataTypes.of("ROW<name STRING, value INT>"))
+ .withPartitionBy("input", "name")
+ .build()) {
+
+ harness.processElementForTable("input", Row.of("Alice", 10));
+ harness.processElementForTable("input", Row.of("Bob", 20));
+ harness.processElementForTable("input", Row.of("Charlie", 30));
+
+ Set<Row> keys = harness.getKeysForState("state");
+ assertThat(keys)
+ .containsExactlyInAnyOrder(Row.of("Alice"), Row.of("Bob"), Row.of("Charlie"));
+ }
+ }
+
+ /** Checks that getStateForAllKeys() returns one state entry per processed partition key. */
+ @Test
+ void testGetAllState() throws Exception {
+ try (ProcessTableFunctionTestHarness<Row> harness =
+ ProcessTableFunctionTestHarness.ofClass(PTFWithValueState.class)
+ .withTableArgument("input", DataTypes.of("ROW<name STRING, value INT>"))
+ .withPartitionBy("input", "name")
+ .build()) {
+
+ harness.processElementForTable("input", Row.of("Alice", 10));
+ harness.processElementForTable("input", Row.of("Alice", 15));
+ harness.processElementForTable("input", Row.of("Bob", 20));
+
+ Map<Row, PTFWithValueState.CounterState> allState = harness.getStateForAllKeys("state");
+
+ assertThat(allState).hasSize(2);
+ assertThat(allState.get(Row.of("Alice")).counter).isEqualTo(2L);
+ assertThat(allState.get(Row.of("Bob")).counter).isEqualTo(1L);
+ }
+ }
+
+ /** Checks that a ListView state accumulates values across elements of the same partition. */
+ @Test
+ void testListViewState() throws Exception {
+ try (ProcessTableFunctionTestHarness<Row> harness =
+ ProcessTableFunctionTestHarness.ofClass(PTFWithListViewState.class)
+ .withTableArgument("input", DataTypes.of("ROW<key STRING, value INT>"))
+ .withPartitionBy("input", "key")
+ .build()) {
+
+ harness.processElementForTable("input", Row.of("A", 1));
+ assertThat(harness.getOutput()).containsExactly(Row.of("A", new Integer[] {1}));
+
+ harness.processElementForTable("input", Row.of("A", 2));
+ assertThat(harness.getOutput().get(1)).isEqualTo(Row.of("A", new Integer[] {1, 2}));
+
+ ListView<Integer> listState = harness.getStateForKey("listState", Row.of("A"));
+ assertThat(listState.get()).containsExactly(1, 2);
+ }
+ }
+
+ /** Checks that a MapView state tracks separate entries per map key within one partition. */
+ @Test
+ void testMapViewState() throws Exception {
+ try (ProcessTableFunctionTestHarness<Row> harness =
+ ProcessTableFunctionTestHarness.ofClass(PTFWithMapViewState.class)
+ .withTableArgument(
+ "input", DataTypes.of("ROW<partition STRING, key STRING>"))
+ .withPartitionBy("input", "partition")
+ .build()) {
+
+ harness.processElementForTable("input", Row.of("P1", "foo"));
+ assertThat(harness.getOutput()).containsExactly(Row.of("P1", "foo", 1));
+
+ harness.processElementForTable("input", Row.of("P1", "foo"));
+ assertThat(harness.getOutput().get(1)).isEqualTo(Row.of("P1", "foo", 2));
+
+ harness.processElementForTable("input", Row.of("P1", "bar"));
+ assertThat(harness.getOutput().get(2)).isEqualTo(Row.of("P1", "bar", 1));
+
+ MapView<String, Integer> mapState = harness.getStateForKey("mapState", Row.of("P1"));
+ assertThat(mapState.get("foo")).isEqualTo(2);
+ assertThat(mapState.get("bar")).isEqualTo(1);
+ }
+ }
+
+ /** Checks that a Row-typed state is updated per element and readable via getStateForKey(). */
+ @Test
+ void testRowState() throws Exception {
+ try (ProcessTableFunctionTestHarness<Row> harness =
+ ProcessTableFunctionTestHarness.ofClass(PTFWithRowState.class)
+ .withTableArgument("input", DataTypes.of("ROW<name STRING, value INT>"))
+ .withPartitionBy("input", "name")
+ .build()) {
+
+ harness.processElementForTable("input", Row.of("Alice", 10));
+ assertThat(harness.getOutput()).containsExactly(Row.of("Alice", 1L));
+
+ harness.processElementForTable("input", Row.of("Alice", 20));
+ assertThat(harness.getOutput().get(1)).isEqualTo(Row.of("Alice", 2L));
+
+ Row state = harness.getStateForKey("memory", Row.of("Alice"));
+ assertThat((Long) state.getFieldAs("count")).isEqualTo(2L);
+ }
+ }
+
+ /** Checks that getStateForKey() returns null for a key that has not been processed. */
+ @Test
+ void testEmptyState() throws Exception {
+ try (ProcessTableFunctionTestHarness<Row> harness =
+ ProcessTableFunctionTestHarness.ofClass(PTFWithValueState.class)
+ .withTableArgument("input", DataTypes.of("ROW<name STRING, value INT>"))
+ .withPartitionBy("input", "name")
+ .build()) {
+
+ PTFWithValueState.CounterState state = harness.getStateForKey("state", Row.of("Alice"));
+
+ assertThat(state).isNull();
+ }
+ }
+
+ /** Checks that clearAllStatesForKey() removes the key's state and later elements start fresh. */
+ @Test
+ void testClearAllStatesForKey() throws Exception {
+ try (ProcessTableFunctionTestHarness<Row> harness =
+ ProcessTableFunctionTestHarness.ofClass(PTFWithValueState.class)
+ .withTableArgument("input", DataTypes.of("ROW<name STRING, value INT>"))
+ .withPartitionBy("input", "name")
+ .build()) {
+
+ harness.processElementForTable("input", Row.of("Alice", 10));
+ harness.processElementForTable("input", Row.of("Alice", 15));
+
+ PTFWithValueState.CounterState state = harness.getStateForKey("state", Row.of("Alice"));
+ assertThat(state.counter).isEqualTo(2L);
+
+ harness.clearAllStatesForKey(Row.of("Alice"));
+
+ state = harness.getStateForKey("state", Row.of("Alice"));
+ assertThat(state).isNull();
+
+ harness.processElementForTable("input", Row.of("Alice", 30));
+ state = harness.getStateForKey("state", Row.of("Alice"));
+ assertThat(state.counter).isEqualTo(1L);
+ }
+ }
+
+ /** Checks that clearStateForKey() resets the named state (counter reads 0) and processing resumes. */
+ @Test
+ void testClearStateForKey() throws Exception {
+ try (ProcessTableFunctionTestHarness<Row> harness =
+ ProcessTableFunctionTestHarness.ofClass(PTFWithValueState.class)
+ .withTableArgument("input", DataTypes.of("ROW<name STRING, value INT>"))
+ .withPartitionBy("input", "name")
+ .build()) {
+
+ harness.processElementForTable("input", Row.of("Alice", 10));
+ harness.processElementForTable("input", Row.of("Alice", 15));
+
+ PTFWithValueState.CounterState state = harness.getStateForKey("state", Row.of("Alice"));
+ assertThat(state.counter).isEqualTo(2L);
+
+ harness.clearStateForKey("state", Row.of("Alice"));
+
+ state = harness.getStateForKey("state", Row.of("Alice"));
+ assertThat(state.counter).isEqualTo(0L);
+
+ harness.processElementForTable("input", Row.of("Alice", 30));
+ state = harness.getStateForKey("state", Row.of("Alice"));
+ assertThat(state.counter).isEqualTo(1L);
+ }
+ }
+
+ /** Checks a PTF with two state parameters: a POJO counter and a ListView history. */
+ @Test
+ void testMultipleStateParameters() throws Exception {
+ try (ProcessTableFunctionTestHarness<Row> harness =
+ ProcessTableFunctionTestHarness.ofClass(PTFWithMultipleStates.class)
+ .withTableArgument("input", DataTypes.of("ROW<key STRING, value INT>"))
+ .withPartitionBy("input", "key")
+ .build()) {
+
+ harness.processElementForTable("input", Row.of("A", 10));
+ harness.processElementForTable("input", Row.of("A", 20));
+ harness.processElementForTable("input", Row.of("B", 5));
+
+ assertThat(harness.getOutput())
+ .containsExactly(Row.of("A", 1L, 10), Row.of("A", 2L, 30), Row.of("B", 1L, 5));
+
+ PTFWithMultipleStates.CounterState counterA =
+ harness.getStateForKey("counter", Row.of("A"));
+ assertThat(counterA.count).isEqualTo(2L);
+
+ ListView<Integer> historyA = harness.getStateForKey("history", Row.of("A"));
+ assertThat(historyA.get()).containsExactly(10, 20);
+ }
+ }
+
+ /** Checks that a ListView seeded via withInitialStateForKey() is visible and extended on eval. */
+ @Test
+ void testInitialStateWithListView() throws Exception {
+ ListView<Integer> initialList = new ListView<>();
+ initialList.add(100);
+ initialList.add(200);
+
+ try (ProcessTableFunctionTestHarness<Row> harness =
+ ProcessTableFunctionTestHarness.ofClass(PTFWithListViewState.class)
+ .withTableArgument("input", DataTypes.of("ROW<key STRING, value INT>"))
+ .withPartitionBy("input", "key")
+ .withInitialStateForKey("listState", Row.of("A"), initialList)
+ .build()) {
+
+ ListView<Integer> listState = harness.getStateForKey("listState", Row.of("A"));
+ assertThat(listState.get()).containsExactly(100, 200);
+
+ harness.processElementForTable("input", Row.of("A", 3));
+ assertThat(harness.getOutput()).containsExactly(Row.of("A", new Integer[] {100, 200, 3}));
+ }
+ }
+
+ /** Checks that a MapView seeded via withInitialStateForKey() is visible and updated on eval. */
+ @Test
+ void testInitialStateWithMapView() throws Exception {
+ MapView<String, Integer> initialMap = new MapView<>();
+ initialMap.put("existing", 42);
+
+ try (ProcessTableFunctionTestHarness<Row> harness =
+ ProcessTableFunctionTestHarness.ofClass(PTFWithMapViewState.class)
+ .withTableArgument(
+ "input", DataTypes.of("ROW<partition STRING, key STRING>"))
+ .withPartitionBy("input", "partition")
+ .withInitialStateForKey("mapState", Row.of("P1"), initialMap)
+ .build()) {
+
+ MapView<String, Integer> mapState = harness.getStateForKey("mapState", Row.of("P1"));
+ assertThat(mapState.get("existing")).isEqualTo(42);
+
+ harness.processElementForTable("input", Row.of("P1", "existing"));
+ assertThat(harness.getOutput()).containsExactly(Row.of("P1", "existing", 43));
+ }
+ }
+
+ /** Checks that withInitialStateForKey() rejects a key whose arity differs from the partition key. */
+ @Test
+ void testInitialStateKeyArityMismatch() {
+ Exception exception =
+ assertThrows(
+ IllegalArgumentException.class,
+ () ->
+ ProcessTableFunctionTestHarness.ofClass(PTFWithValueState.class)
+ .withTableArgument(
+ "input", DataTypes.of("ROW<name STRING, value INT>"))
+ .withPartitionBy("input", "name")
+ .withInitialStateForKey(
+ "state",
+ Row.of("Alice", 42),
+ new PTFWithValueState.CounterState())
+ .build());
+
+ assertThat(exception.getMessage())
+ .contains("state")
+ .contains("arity 2")
+ .contains("arity 1");
+ }
+
+ /** Checks that withInitialStateForKey() rejects a key whose field type differs from the partition key. */
+ @Test
+ void testInitialStateKeyTypeMismatch() {
+ Exception exception =
+ assertThrows(
+ IllegalArgumentException.class,
+ () ->
+ ProcessTableFunctionTestHarness.ofClass(PTFWithValueState.class)
+ .withTableArgument(
+ "input", DataTypes.of("ROW<name STRING, value INT>"))
+ .withPartitionBy("input", "name")
+ .withInitialStateForKey(
+ "state", Row.of(42), new PTFWithValueState.CounterState())
+ .build());
+
+ assertThat(exception.getMessage())
+ .contains("state")
+ .contains("Integer")
+ .contains("name")
+ .contains("String");
+ }
+
+ /** Checks that setStateForKey() overwrites a key's state and later processing continues from it. */
+ @Test
+ void testSetStateForKey() throws Exception {
+ try (ProcessTableFunctionTestHarness<Row> harness =
+ ProcessTableFunctionTestHarness.ofClass(PTFWithValueState.class)
+ .withTableArgument("input", DataTypes.of("ROW<name STRING, value INT>"))
+ .withPartitionBy("input", "name")
+ .build()) {
+
+ harness.processElementForTable("input", Row.of("Alice", 10));
+ harness.processElementForTable("input", Row.of("Alice", 20));
+
+ PTFWithValueState.CounterState state = harness.getStateForKey("state", Row.of("Alice"));
+ assertThat(state.counter).isEqualTo(2L);
+
+ PTFWithValueState.CounterState newState = new PTFWithValueState.CounterState();
+ newState.counter = 50L;
+ harness.setStateForKey("state", Row.of("Alice"), newState);
+
+ state = harness.getStateForKey("state", Row.of("Alice"));
+ assertThat(state.counter).isEqualTo(50L);
+
+ harness.processElementForTable("input", Row.of("Alice", 30));
+ assertThat(harness.getOutput().get(2)).isEqualTo(Row.of("Alice", 51L));
+ }
+ }
+
+ /** Checks that build() rejects initial state registered under an unknown state name. */
+ @Test
+ void testInvalidStateNameInWithInitialState() {
+ Exception exception =
+ assertThrows(
+ IllegalArgumentException.class,
+ () ->
+ ProcessTableFunctionTestHarness.ofClass(PTFWithValueState.class)
+ .withTableArgument("input", DataTypes.of("ROW<id INT>"))
+ .withPartitionBy("input", "id")
+ .withInitialStateForKey("nonExistentState", Row.of(1), "value")
+ .build());
+
+ assertThat(exception.getMessage())
+ .contains("Unknown state")
+ .contains("nonExistentState")
+ .contains("Available states")
+ .contains("state");
+ }
+
+ // -------------------------------------------------------------------------
+ // Partition Key Validation Tests
+ // -------------------------------------------------------------------------
+
+ /** Checks that getStateForKey() rejects a key with the wrong number of fields. */
+ @Test
+ void testPartitionKeyValidationWrongArity() throws Exception {
+ try (ProcessTableFunctionTestHarness<Row> harness =
+ ProcessTableFunctionTestHarness.ofClass(PTFWithValueState.class)
+ .withTableArgument("input", DataTypes.of("ROW<name STRING, value INT>"))
+ .withPartitionBy("input", "name")
+ .build()) {
+
+ Exception exception =
+ assertThrows(
+ IllegalArgumentException.class,
+ () -> harness.getStateForKey("state", Row.of("Alice", "extra")));
+ assertThat(exception.getMessage()).contains("arity 2").contains("expected arity 1");
+ }
+ }
+
+ /** Checks that getStateForKey() rejects a key whose field type differs from the partition key. */
+ @Test
+ void testPartitionKeyValidationWrongType() throws Exception {
+ try (ProcessTableFunctionTestHarness<Row> harness =
+ ProcessTableFunctionTestHarness.ofClass(PTFWithValueState.class)
+ .withTableArgument("input", DataTypes.of("ROW<name STRING, value INT>"))
+ .withPartitionBy("input", "name")
+ .build()) {
+
+ Exception exception =
+ assertThrows(
+ IllegalArgumentException.class,
+ () -> harness.getStateForKey("state", Row.of(123)));
+ assertThat(exception.getMessage())
+ .contains("Integer")
+ .contains("name")
+ .contains("String");
+ }
+ }
+
+ /** Checks that setStateForKey() validates the partition key before writing. */
+ @Test
+ void testPartitionKeyValidationOnSetState() throws Exception {
+ PTFWithValueState.CounterState state = new PTFWithValueState.CounterState();
+ state.counter = 1L;
+
+ try (ProcessTableFunctionTestHarness<Row> harness =
+ ProcessTableFunctionTestHarness.ofClass(PTFWithValueState.class)
+ .withTableArgument("input", DataTypes.of("ROW<name STRING, value INT>"))
+ .withPartitionBy("input", "name")
+ .build()) {
+
+ assertThrows(
+ IllegalArgumentException.class,
+ () -> harness.setStateForKey("state", Row.of(1, 2), state));
+ }
+ }
+
+ /** Checks that clearAllStatesForKey() validates the partition key. */
+ @Test
+ void testPartitionKeyValidationOnClearAllStates() throws Exception {
+ try (ProcessTableFunctionTestHarness<Row> harness =
+ ProcessTableFunctionTestHarness.ofClass(PTFWithValueState.class)
+ .withTableArgument("input", DataTypes.of("ROW<name STRING, value INT>"))
+ .withPartitionBy("input", "name")
+ .build()) {
+
+ assertThrows(
+ IllegalArgumentException.class,
+ () -> harness.clearAllStatesForKey(Row.of("a", "b")));
+ }
+ }
+
+ /** Checks that clearStateForKey() validates the partition key. */
+ @Test
+ void testPartitionKeyValidationOnClearState() throws Exception {
+ try (ProcessTableFunctionTestHarness<Row> harness =
+ ProcessTableFunctionTestHarness.ofClass(PTFWithValueState.class)
+ .withTableArgument("input", DataTypes.of("ROW<name STRING, value INT>"))
+ .withPartitionBy("input", "name")
+ .build()) {
+
+ assertThrows(
+ IllegalArgumentException.class,
+ () -> harness.clearStateForKey("state", Row.of(42)));
+ }
+ }
}