This is an automated email from the ASF dual-hosted git repository. lindong pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink-ml.git
commit ba7760792827634664bb1962cc7ca6e9c161f255 Author: zhangzp <zhangzhipe...@gmail.com> AuthorDate: Mon Jun 20 09:42:17 2022 +0800 [FLINK-27877] Improve performance for StringIndexer --- .../ml/feature/stringindexer/StringIndexer.java | 219 +++++++++++---------- 1 file changed, 117 insertions(+), 102 deletions(-) diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexer.java index ee560e0..c8312fa 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexer.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexer.java @@ -18,31 +18,40 @@ package org.apache.flink.ml.feature.stringindexer; -import org.apache.flink.api.common.functions.FlatMapFunction; -import org.apache.flink.api.common.functions.MapPartitionFunction; -import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeHint; +import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.tuple.Tuple2; -import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.iteration.operator.OperatorStateUtils; import org.apache.flink.ml.api.Estimator; import org.apache.flink.ml.common.datastream.DataStreamUtils; import org.apache.flink.ml.common.param.HasHandleInvalid; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.table.api.Table; import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; import org.apache.flink.table.api.internal.TableImpl; import org.apache.flink.types.Row; -import org.apache.flink.util.Collector; import org.apache.flink.util.Preconditions; import java.io.IOException; import java.util.ArrayList; +import java.util.Collections; import java.util.Comparator; import java.util.HashMap; -import java.util.List; import java.util.Map; +import java.util.Map.Entry; /** * An Estimator which implements the string indexing algorithm. @@ -91,19 +100,34 @@ public class StringIndexer StreamTableEnvironment tEnv = (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); - DataStream<Tuple2<Integer, String>> columnIdAndString = - tEnv.toDataStream(inputs[0]).flatMap(new ExtractColumnIdAndString(inputCols)); + DataStream<HashMap<String, Long>[]> localCountedString = + tEnv.toDataStream(inputs[0]) + .transform( + "countStringOperator", + TypeInformation.of(new TypeHint<HashMap<String, Long>[]>() {}), + new CountStringOperator(inputCols)); - DataStream<Tuple3<Integer, String, Long>> columnIdAndStringAndCnt = - DataStreamUtils.mapPartition( - columnIdAndString.keyBy( - (KeySelector<Tuple2<Integer, String>, Integer>) Tuple2::hashCode), - new CountStringsByColumn(inputCols.length)); + DataStream<HashMap<String, Long>[]> countedString = + DataStreamUtils.reduce( + localCountedString, + (ReduceFunction<HashMap<String, Long>[]>) + (value1, value2) -> { + for (int i = 0; i < value1.length; i++) { + for (Entry<String, Long> stringAndCnt : + value2[i].entrySet()) { + value1[i].compute( + stringAndCnt.getKey(), + (k, v) -> + (v == null + ? stringAndCnt.getValue() + : v + stringAndCnt.getValue())); + } + } + return value1; + }); DataStream<StringIndexerModelData> modelData = - DataStreamUtils.mapPartition( - columnIdAndStringAndCnt, - new GenerateModel(inputCols.length, getStringOrderType())); + countedString.map(new ModelGenerator(getStringOrderType())); modelData.getTransformation().setParallelism(1); StringIndexerModel model = @@ -112,38 +136,93 @@ public class StringIndexer return model; } + /** Computes the occurrence time of each string by columns. */ + private static class CountStringOperator extends AbstractStreamOperator<HashMap<String, Long>[]> + implements OneInputStreamOperator<Row, HashMap<String, Long>[]>, BoundedOneInput { + /** The name of input columns. */ + private final String[] inputCols; + /** The occurrence time of each string by column. */ + private HashMap<String, Long>[] stringCntByColumn; + /** The state of stringCntByColumn. */ + private ListState<HashMap<String, Long>[]> stringCntByColumnState; + + public CountStringOperator(String[] inputCols) { + this.inputCols = inputCols; + stringCntByColumn = new HashMap[inputCols.length]; + for (int i = 0; i < stringCntByColumn.length; i++) { + stringCntByColumn[i] = new HashMap<>(); + } + } + + @Override + public void endInput() { + output.collect(new StreamRecord<>(stringCntByColumn)); + stringCntByColumnState.clear(); + } + + @Override + public void processElement(StreamRecord<Row> element) { + Row r = element.getValue(); + for (int i = 0; i < inputCols.length; i++) { + Object objVal = r.getField(inputCols[i]); + String stringVal; + if (objVal instanceof String) { + stringVal = (String) objVal; + } else if (objVal instanceof Number) { + stringVal = String.valueOf(objVal); + } else { + throw new RuntimeException( + "The input column only supports string and numeric type."); + } + stringCntByColumn[i].compute(stringVal, (k, v) -> (v == null ? 1 : v + 1)); + } + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + stringCntByColumnState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "stringCntByColumnState", + TypeInformation.of( + new TypeHint<HashMap<String, Long>[]>() {}))); + + OperatorStateUtils.getUniqueElement(stringCntByColumnState, "stringCntByColumnState") + .ifPresent(x -> stringCntByColumn = x); + } + + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + super.snapshotState(context); + stringCntByColumnState.update(Collections.singletonList(stringCntByColumn)); + } + } + /** * Merges all the extracted strings and generates the {@link StringIndexerModelData} according * to the specified string order type. */ - private static class GenerateModel - implements MapPartitionFunction<Tuple3<Integer, String, Long>, StringIndexerModelData> { - private final int numCols; + private static class ModelGenerator + implements MapFunction<HashMap<String, Long>[], StringIndexerModelData> { private final String stringOrderType; - public GenerateModel(int numCols, String stringOrderType) { - this.numCols = numCols; + public ModelGenerator(String stringOrderType) { this.stringOrderType = stringOrderType; } @Override - @SuppressWarnings("unchecked") - public void mapPartition( - Iterable<Tuple3<Integer, String, Long>> values, - Collector<StringIndexerModelData> out) { + public StringIndexerModelData map(HashMap<String, Long>[] value) { + int numCols = value.length; String[][] stringArrays = new String[numCols][]; - ArrayList<Tuple2<String, Long>>[] stringsAndCntsByColumn = new ArrayList[numCols]; + ArrayList<Tuple2<String, Long>> stringsAndCnts = new ArrayList<>(); for (int i = 0; i < numCols; i++) { - stringsAndCntsByColumn[i] = new ArrayList<>(); - } - - for (Tuple3<Integer, String, Long> colIdAndStringAndCnt : values) { - stringsAndCntsByColumn[colIdAndStringAndCnt.f0].add( - Tuple2.of(colIdAndStringAndCnt.f1, colIdAndStringAndCnt.f2)); - } - - for (int i = 0; i < stringsAndCntsByColumn.length; i++) { - List<Tuple2<String, Long>> stringsAndCnts = stringsAndCntsByColumn[i]; + stringsAndCnts.clear(); + stringsAndCnts.ensureCapacity(value[i].size()); + for (Map.Entry<String, Long> entry : value[i].entrySet()) { + stringsAndCnts.add(Tuple2.of(entry.getKey(), entry.getValue())); + } switch (stringOrderType) { case ALPHABET_ASC_ORDER: stringsAndCnts.sort(Comparator.comparing(valAndCnt -> valAndCnt.f0)); @@ -171,74 +250,10 @@ public class StringIndexer + stringOrderType + "."); } - - stringArrays[i] = new String[stringsAndCnts.size()]; - for (int stringId = 0; stringId < stringArrays[i].length; stringId++) { - stringArrays[i][stringId] = stringsAndCnts.get(stringId).f0; - } - } - - out.collect(new StringIndexerModelData(stringArrays)); - } - } - - /** Computes the frequency of strings in each column. */ - private static class CountStringsByColumn - implements MapPartitionFunction< - Tuple2<Integer, String>, Tuple3<Integer, String, Long>> { - private final int numCols; - - public CountStringsByColumn(int numCols) { - this.numCols = numCols; - } - - @Override - @SuppressWarnings("unchecked") - public void mapPartition( - Iterable<Tuple2<Integer, String>> values, - Collector<Tuple3<Integer, String, Long>> out) { - HashMap<String, Long>[] string2CntByColumn = new HashMap[numCols]; - for (int i = 0; i < numCols; i++) { - string2CntByColumn[i] = new HashMap<>(); + stringArrays[i] = stringsAndCnts.stream().map(x -> x.f0).toArray(String[]::new); } - for (Tuple2<Integer, String> columnIdAndString : values) { - int colId = columnIdAndString.f0; - String stringVal = columnIdAndString.f1; - long cnt = string2CntByColumn[colId].getOrDefault(stringVal, 0L) + 1; - string2CntByColumn[colId].put(stringVal, cnt); - } - for (int i = 0; i < numCols; i++) { - for (Map.Entry<String, Long> entry : string2CntByColumn[i].entrySet()) { - out.collect(Tuple3.of(i, entry.getKey(), entry.getValue())); - } - } - } - } - /** Extracts strings by column. */ - private static class ExtractColumnIdAndString - implements FlatMapFunction<Row, Tuple2<Integer, String>> { - private final String[] inputCols; - - public ExtractColumnIdAndString(String[] inputCols) { - this.inputCols = inputCols; - } - - @Override - public void flatMap(Row row, Collector<Tuple2<Integer, String>> out) { - for (int i = 0; i < inputCols.length; i++) { - Object objVal = row.getField(inputCols[i]); - String stringVal; - if (objVal instanceof String) { - stringVal = (String) objVal; - } else if (objVal instanceof Number) { - stringVal = String.valueOf(objVal); - } else { - throw new RuntimeException( - "The input column only supports string and numeric type."); - } - out.collect(Tuple2.of(i, stringVal)); - } + return new StringIndexerModelData(stringArrays); } } }