tsreaper commented on code in PR #121: URL: https://github.com/apache/flink-table-store/pull/121#discussion_r884370488
########## flink-table-store-connector/src/test/java/org/apache/flink/table/store/connector/AggregationITCase.java: ########## @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.store.connector; + +import org.apache.flink.types.Row; + +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ExecutionException; + +import static org.apache.flink.util.CollectionUtil.iteratorToList; +import static org.assertj.core.api.Assertions.assertThat; + +/** ITCase for partial update. 
*/ +public class AggregationITCase extends FileStoreTableITCase { + + @Override + protected List<String> ddl() { + + String ddl1 = + "CREATE TABLE IF NOT EXISTS T3 ( " + + " a STRING, " + + " b BIGINT, " + + " c INT, " + + " PRIMARY KEY (a) NOT ENFORCED )" + + " WITH (" + + " 'merge-engine'='aggregation' ," + + " 'b.aggregate-function'='sum' ," + + " 'c.aggregate-function'='sum' " + + " );"; + String ddl2 = + "CREATE TABLE IF NOT EXISTS T4 ( " + + " a STRING," + + " b INT," + + " c DOUBLE," + + " PRIMARY KEY (a, b) NOT ENFORCED )" + + " WITH (" + + " 'merge-engine'='aggregation'," + + " 'c.aggregate-function' = 'sum'" + + " );"; + String ddl3 = + "CREATE TABLE IF NOT EXISTS T5 ( " + + " a STRING," + + " b INT," + + " c DOUBLE," + + " PRIMARY KEY (a) NOT ENFORCED )" + + " WITH (" + + " 'merge-engine'='aggregation'," + + " 'b.aggregate-function' = 'sum'" + + " );"; + List<String> lists = new ArrayList<>(); + lists.add(ddl1); + lists.add(ddl2); + lists.add(ddl3); + return lists; + } + + @Test + public void testCreateAggregateFunction() throws ExecutionException, InterruptedException { + List<Row> result; + + // T5 + try { + bEnv.executeSql("INSERT INTO T5 VALUES " + "('pk1',1, 2.0), " + "('pk1',1, 2.0)") + .await(); + throw new AssertionError("create table T5 should failed"); + } catch (IllegalArgumentException e) { + assert ("should set aggregate function for every column not part of primary key" + .equals(e.getLocalizedMessage())); + } + } + + @Test + public void testMergeInMemory() throws ExecutionException, InterruptedException { + List<Row> result; + // T3 + bEnv.executeSql("INSERT INTO T3 VALUES " + "('pk1',1, 2), " + "('pk1',1, 2)").await(); + result = iteratorToList(bEnv.from("T3").execute().collect()); + assertThat(result).containsExactlyInAnyOrder(Row.of("pk1", 2L, 4)); + + // T4 + bEnv.executeSql("INSERT INTO T4 VALUES " + "('pk1',1, 2.0), " + "('pk1',1, 2.0)").await(); + result = iteratorToList(bEnv.from("T4").execute().collect()); + 
assertThat(result).containsExactlyInAnyOrder(Row.of("pk1", 1, 4.0)); + } + + @Test + public void testMergeRead() throws ExecutionException, InterruptedException { + List<Row> result; + // T3 + bEnv.executeSql("INSERT INTO T3 VALUES ('pk1',1, 2)").await(); + bEnv.executeSql("INSERT INTO T3 VALUES ('pk1',1, 4)").await(); + bEnv.executeSql("INSERT INTO T3 VALUES ('pk1',2, 0)").await(); + result = iteratorToList(bEnv.from("T3").execute().collect()); + assertThat(result).containsExactlyInAnyOrder(Row.of("pk1", 4L, 6)); + + // T4 + bEnv.executeSql("INSERT INTO T4 VALUES ('pk1',1, 2.0)").await(); + bEnv.executeSql("INSERT INTO T4 VALUES ('pk1',1, 4.0)").await(); + bEnv.executeSql("INSERT INTO T4 VALUES ('pk1',1, 0.0)").await(); + result = iteratorToList(bEnv.from("T4").execute().collect()); + assertThat(result).containsExactlyInAnyOrder(Row.of("pk1", 1, 6.0)); + } + + @Test + public void testMergeCompaction() throws ExecutionException, InterruptedException { + List<Row> result; + + // T3 + // Wait compaction + bEnv.executeSql("ALTER TABLE T3 SET ('commit.force-compact'='true')"); + + // key pk1 + bEnv.executeSql("INSERT INTO T3 VALUES ('pk1', 3, 1)").await(); + bEnv.executeSql("INSERT INTO T3 VALUES ('pk1', 4, 5)").await(); + bEnv.executeSql("INSERT INTO T3 VALUES ('pk1', 4, 6)").await(); + + // key pk2 + bEnv.executeSql("INSERT INTO T3 VALUES ('pk2', 6,7)").await(); + bEnv.executeSql("INSERT INTO T3 VALUES ('pk2', 9,0)").await(); + bEnv.executeSql("INSERT INTO T3 VALUES ('pk2', 4,4)").await(); + + result = iteratorToList(bEnv.from("T3").execute().collect()); + assertThat(result) + .containsExactlyInAnyOrder(Row.of("pk1", 11L, 12), Row.of("pk2", 19L, 11)); + + // T4 + // Wait compaction + bEnv.executeSql("ALTER TABLE T4 SET ('commit.force-compact'='true')"); + + // key pk1_3 + bEnv.executeSql("INSERT INTO T4 VALUES ('pk1', 3, 1.0)").await(); + // key pk1_4 + bEnv.executeSql("INSERT INTO T4 VALUES ('pk1', 4, 5.0)").await(); + bEnv.executeSql("INSERT INTO T4 VALUES ('pk1', 
4, 6.0)").await(); + // key pk2_4 + bEnv.executeSql("INSERT INTO T4 VALUES ('pk2', 4,4.0)").await(); + // key pk2_2 + bEnv.executeSql("INSERT INTO T4 VALUES ('pk2', 2,7.0)").await(); + bEnv.executeSql("INSERT INTO T4 VALUES ('pk2', 2,0)").await(); + + result = iteratorToList(bEnv.from("T4").execute().collect()); + assertThat(result) + .containsExactlyInAnyOrder( + Row.of("pk1", 3, 1.0), + Row.of("pk1", 4, 11.0), + Row.of("pk2", 4, 4.0), + Row.of("pk2", 2, 7.0)); + } + + @Test + public void myTest() throws Exception { Review Comment: This is not a valid test. ########## flink-table-store-core/src/main/java/org/apache/flink/table/store/file/mergetree/compact/ColumnAggregateFunctionFactory.java: ########## @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.store.file.mergetree.compact; + +import org.apache.flink.table.types.logical.LogicalType; + +/** Factory for creating {@link ColumnAggregateFunction}s. */ +public class ColumnAggregateFunctionFactory { + /** + * Determine the column aggregation function . 
+ * + * @param kind the kind of aggregation + * @param typeAt the type of the column + * @return the column aggregation function + */ + public static ColumnAggregateFunction<?> getColumnAggregateFunction( + AggregationKind kind, LogicalType typeAt) { + switch (kind) { + case Sum: + return SumColumnAggregateFunctionFactory.getColumnAggregateFunction(typeAt); + case Avg: + return AvgColumnAggregateFunctionFactory.getColumnAggregateFunction(typeAt); + case Max: + return MaxColumnAggregateFunctionFactory.getColumnAggregateFunction(typeAt); + case Min: + return MinColumnAggregateFunctionFactory.getColumnAggregateFunction(typeAt); + default: + throw new IllegalArgumentException("Aggregation kind " + kind + " not supported"); + } + } + + /** AggregateKind is Sum. Determine the column aggregation function . */ + private static class SumColumnAggregateFunctionFactory { + static SumColumnAggregateFunction<?> getColumnAggregateFunction(LogicalType type) { + switch (type.getTypeRoot()) { + case CHAR: + case VARCHAR: + case BOOLEAN: + case BINARY: + case VARBINARY: + case DECIMAL: + case TINYINT: + case SMALLINT: Review Comment: These types can also be supported. The internal class for `DECIMAL` is `DecimalData`. There are methods like `add` and `compare` in `DecimalDataUtils` class. The internal class for `TINYINT` is `byte`. For `SMALLINT` that is `short`. You can support them easily. ########## docs/content/docs/development/create-table.md: ########## @@ -268,3 +268,45 @@ For example, the inputs: Output: - <1, 25.2, 20, 'This is a book'> + +## Aggregation Update + +You can configure partial update from options: Review Comment: > partial update from options I guess you're copying this from the partial update merge engine. Change this description. 
########## flink-table-store-core/src/main/java/org/apache/flink/table/store/file/mergetree/compact/AggregationKind.java: ########## @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.store.file.mergetree.compact; + +import java.util.Locale; + +/** Aggregate kinds. */ +public enum AggregationKind { + Sum, + Max, + Min, + Avg; Review Comment: Enum values should be upper-cased letters. Not sure if checkstyle will check this, but in Flink's code base this seems to be a convention. ########## flink-table-store-core/src/main/java/org/apache/flink/table/store/file/mergetree/compact/AggregateMergeFunction.java: ########## @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.store.file.mergetree.compact; + +import org.apache.flink.table.data.GenericRowData; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.store.file.FileStoreOptions; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.RowType; + +import javax.annotation.Nullable; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * A {@link MergeFunction} where key is primary key (unique) and value is the partial record, + * aggregate specifies field on merge. 
+ */ +@SuppressWarnings("checkstyle:RegexpSingleline") +public class AggregateMergeFunction implements MergeFunction { + + private static final long serialVersionUID = 1L; + + private final RowData.FieldGetter[] getters; + + private final RowType rowType; + private final ArrayList<ColumnAggregateFunction<?>> aggregateFunctions; + private final boolean[] isPrimaryKey; + private final RowType primaryKeyType; + private transient GenericRowData row; + private final Map<String, AggregationKind> aggregationKindMap; + + public AggregateMergeFunction( + RowType primaryKeyType, + RowType rowType, + Map<String, AggregationKind> aggregationKindMap) { + this.primaryKeyType = primaryKeyType; + this.rowType = rowType; + this.aggregationKindMap = aggregationKindMap; + + List<LogicalType> fieldTypes = rowType.getChildren(); + this.getters = new RowData.FieldGetter[fieldTypes.size()]; + for (int i = 0; i < fieldTypes.size(); i++) { + getters[i] = RowData.createFieldGetter(fieldTypes.get(i), i); + } + + this.isPrimaryKey = new boolean[this.getters.length]; + List<String> rowNames = rowType.getFieldNames(); + for (String primaryKeyName : primaryKeyType.getFieldNames()) { + isPrimaryKey[rowNames.indexOf(primaryKeyName)] = true; + } + + this.aggregateFunctions = new ArrayList<>(rowType.getFieldCount()); + for (int i = 0; i < rowType.getFieldCount(); i++) { + ColumnAggregateFunction<?> f = null; + if (aggregationKindMap.containsKey(rowNames.get(i))) { + f = + ColumnAggregateFunctionFactory.getColumnAggregateFunction( + aggregationKindMap.get(rowNames.get(i)), rowType.getTypeAt(i)); + } else { + if (!isPrimaryKey[i]) { + throw new IllegalArgumentException( + "should set aggregate function for every column not part of primary key"); + } + } + aggregateFunctions.add(f); + } + } + + @Override + public void reset() { + this.row = new GenericRowData(getters.length); + } + + @Override + public void add(RowData value) { + for (int i = 0; i < getters.length; i++) { + Object currentField = 
getters[i].getFieldOrNull(value); + ColumnAggregateFunction<?> f = aggregateFunctions.get(i); + if (isPrimaryKey[i]) { + // primary key + if (currentField != null) { + row.setField(i, currentField); + } + } else { + if (f != null) { + f.reset(); + Object oldValue = row.getField(i); + if (oldValue != null) { + f.aggregate(oldValue); + } + switch (value.getRowKind()) { + case INSERT: + f.aggregate(currentField); + break; + case DELETE: + case UPDATE_AFTER: + case UPDATE_BEFORE: + default: + throw new UnsupportedOperationException( + "Unsupported row kind: " + row.getRowKind()); + } + Object result = f.getResult(); + if (result != null) { + row.setField(i, result); Review Comment: The value of `row` is inspected only when `AggregateMergeFunction#getValue` is called. You've already stored the aggregated results in `f`, so why reset it every time and aggregate twice? My suggestion: * In `AggregateMergeFunction#reset`, reset every aggregate function. * In `AggregateMergeFunction#add`, aggregate each field into the corresponding function. Current results are now stored in the function. * In `AggregateMergeFunction#getValue`, move aggregated results from functions into the row. ########## flink-table-store-core/src/main/java/org/apache/flink/table/store/file/mergetree/compact/AggregationKind.java: ########## @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.store.file.mergetree.compact; + +import java.util.Locale; + +/** Aggregate kinds. */ +public enum AggregationKind { + Sum, + Max, + Min, + Avg; + + public static AggregationKind fromString(String name) { + switch (name.toLowerCase(Locale.ROOT)) { + case "sum": + return Sum; + case "max": + return Max; + case "min": + return Min; + case "avg": + return Avg; + default: + throw new IllegalArgumentException("Unknown aggregation kind: " + name); + } + } Review Comment: No need for this method. Callers can call `AggregationKind.valueOf(name.toUpperCase(Locale.ROOT))` directly (this requires the enum constants themselves to be upper-cased, e.g. `SUM`). ########## flink-table-store-core/src/main/java/org/apache/flink/table/store/file/mergetree/compact/AggregateMergeFunction.java: ########## @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.flink.table.store.file.mergetree.compact; + +import org.apache.flink.table.data.GenericRowData; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.store.file.FileStoreOptions; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.RowType; + +import javax.annotation.Nullable; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * A {@link MergeFunction} where key is primary key (unique) and value is the partial record, + * aggregate specifies field on merge. + */ +@SuppressWarnings("checkstyle:RegexpSingleline") +public class AggregateMergeFunction implements MergeFunction { + + private static final long serialVersionUID = 1L; + + private final RowData.FieldGetter[] getters; + + private final RowType rowType; + private final ArrayList<ColumnAggregateFunction<?>> aggregateFunctions; + private final boolean[] isPrimaryKey; + private final RowType primaryKeyType; + private transient GenericRowData row; + private final Map<String, AggregationKind> aggregationKindMap; + + public AggregateMergeFunction( + RowType primaryKeyType, + RowType rowType, + Map<String, AggregationKind> aggregationKindMap) { + this.primaryKeyType = primaryKeyType; + this.rowType = rowType; + this.aggregationKindMap = aggregationKindMap; + + List<LogicalType> fieldTypes = rowType.getChildren(); + this.getters = new RowData.FieldGetter[fieldTypes.size()]; + for (int i = 0; i < fieldTypes.size(); i++) { + getters[i] = RowData.createFieldGetter(fieldTypes.get(i), i); + } + + this.isPrimaryKey = new boolean[this.getters.length]; + List<String> rowNames = rowType.getFieldNames(); + for (String primaryKeyName : primaryKeyType.getFieldNames()) { + isPrimaryKey[rowNames.indexOf(primaryKeyName)] = true; + } + + this.aggregateFunctions = new ArrayList<>(rowType.getFieldCount()); + for (int i = 0; i < rowType.getFieldCount(); i++) { + 
ColumnAggregateFunction<?> f = null; + if (aggregationKindMap.containsKey(rowNames.get(i))) { + f = + ColumnAggregateFunctionFactory.getColumnAggregateFunction( + aggregationKindMap.get(rowNames.get(i)), rowType.getTypeAt(i)); + } else { + if (!isPrimaryKey[i]) { + throw new IllegalArgumentException( + "should set aggregate function for every column not part of primary key"); + } + } + aggregateFunctions.add(f); + } + } + + @Override + public void reset() { + this.row = new GenericRowData(getters.length); + } + + @Override + public void add(RowData value) { + for (int i = 0; i < getters.length; i++) { + Object currentField = getters[i].getFieldOrNull(value); + ColumnAggregateFunction<?> f = aggregateFunctions.get(i); + if (isPrimaryKey[i]) { + // primary key + if (currentField != null) { + row.setField(i, currentField); + } + } else { + if (f != null) { Review Comment: `f` should never be null in this case. Use `Preconditions.checkNotNull(f, "Aggregate function should never be null. This is unexpected.")` if you need to check this. ########## flink-table-store-connector/src/test/java/org/apache/flink/table/store/connector/AggregationITCase.java: ########## @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.store.connector; + +import org.apache.flink.types.Row; + +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ExecutionException; + +import static org.apache.flink.util.CollectionUtil.iteratorToList; +import static org.assertj.core.api.Assertions.assertThat; + +/** ITCase for partial update. */ +public class AggregationITCase extends FileStoreTableITCase { Review Comment: You now also support other aggregate functions such as `min` and `max`; add tests for them. Remember that every change you make must come with a test. You also support different aggregate functions for different columns; add a test case for this as well. ########## docs/content/docs/development/create-table.md: ########## @@ -268,3 +268,45 @@ For example, the inputs: Output: - <1, 25.2, 20, 'This is a book'> + +## Aggregation Update + +You can configure partial update from options: + +```sql +CREATE TABLE MyTable ( + a STRING, + b INT, + c INT, + PRIMARY KEY (a) NOT ENFORCED +) WITH ( + 'merge-engine' = 'aggregation', + 'b.aggregate-function' = 'sum', + 'c.aggregate-function' = 'sum' +); +``` +{{< hint info >}} +__Note:__Aggregate updates are only supported for tables with primary keys. +{{< /hint >}} + +{{< hint info >}} +__Note:__Aggregate updates do not support streaming consumption. Review Comment: Add a test case for streaming consumption. The test should verify that an exception is thrown, telling the user that streaming consumption is not supported. ########## flink-table-store-core/src/main/java/org/apache/flink/table/store/file/mergetree/compact/AggregateMergeFunction.java: ########## @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.store.file.mergetree.compact; + +import org.apache.flink.table.data.GenericRowData; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.store.file.FileStoreOptions; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.RowType; + +import javax.annotation.Nullable; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * A {@link MergeFunction} where key is primary key (unique) and value is the partial record, + * aggregate specifies field on merge. 
+ */ +@SuppressWarnings("checkstyle:RegexpSingleline") +public class AggregateMergeFunction implements MergeFunction { + + private static final long serialVersionUID = 1L; + + private final RowData.FieldGetter[] getters; + + private final RowType rowType; + private final ArrayList<ColumnAggregateFunction<?>> aggregateFunctions; + private final boolean[] isPrimaryKey; + private final RowType primaryKeyType; + private transient GenericRowData row; + private final Map<String, AggregationKind> aggregationKindMap; + + public AggregateMergeFunction( + RowType primaryKeyType, + RowType rowType, + Map<String, AggregationKind> aggregationKindMap) { + this.primaryKeyType = primaryKeyType; + this.rowType = rowType; + this.aggregationKindMap = aggregationKindMap; + + List<LogicalType> fieldTypes = rowType.getChildren(); + this.getters = new RowData.FieldGetter[fieldTypes.size()]; + for (int i = 0; i < fieldTypes.size(); i++) { + getters[i] = RowData.createFieldGetter(fieldTypes.get(i), i); + } + + this.isPrimaryKey = new boolean[this.getters.length]; + List<String> rowNames = rowType.getFieldNames(); + for (String primaryKeyName : primaryKeyType.getFieldNames()) { + isPrimaryKey[rowNames.indexOf(primaryKeyName)] = true; + } + + this.aggregateFunctions = new ArrayList<>(rowType.getFieldCount()); + for (int i = 0; i < rowType.getFieldCount(); i++) { + ColumnAggregateFunction<?> f = null; + if (aggregationKindMap.containsKey(rowNames.get(i))) { + f = + ColumnAggregateFunctionFactory.getColumnAggregateFunction( + aggregationKindMap.get(rowNames.get(i)), rowType.getTypeAt(i)); + } else { + if (!isPrimaryKey[i]) { + throw new IllegalArgumentException( + "should set aggregate function for every column not part of primary key"); + } + } + aggregateFunctions.add(f); + } + } + + @Override + public void reset() { + this.row = new GenericRowData(getters.length); + } + + @Override + public void add(RowData value) { + for (int i = 0; i < getters.length; i++) { + Object currentField = 
getters[i].getFieldOrNull(value); + ColumnAggregateFunction<?> f = aggregateFunctions.get(i); + if (isPrimaryKey[i]) { + // primary key + if (currentField != null) { + row.setField(i, currentField); + } + } else { + if (f != null) { + f.reset(); + Object oldValue = row.getField(i); + if (oldValue != null) { + f.aggregate(oldValue); + } + switch (value.getRowKind()) { + case INSERT: + f.aggregate(currentField); + break; + case DELETE: + case UPDATE_AFTER: + case UPDATE_BEFORE: + default: + throw new UnsupportedOperationException( + "Unsupported row kind: " + row.getRowKind()); + } + Object result = f.getResult(); + if (result != null) { Review Comment: In SQL, most aggregate functions (all except `count`) ignore null input values and return null only when every input is null. So there is no need to filter out null results here, but you do need to deal with null input values in your aggregate functions. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org