This is an automated email from the ASF dual-hosted git repository. lindong pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push: new 0769392 [FLINK-25616] Add Transformer for VectorAssembler 0769392 is described below commit 076939285ffb7e5167371ae4433a1b4f394e6753 Author: weibo <weibo....@alibaba-inc.com> AuthorDate: Sat Apr 2 15:32:36 2022 +0800 [FLINK-25616] Add Transformer for VectorAssembler This closes #56. --- .../apache/flink/ml/common/param/HasOutputCol.java | 39 +++++ .../feature/vectorassembler/VectorAssembler.java | 182 +++++++++++++++++++++ .../vectorassembler/VectorAssemblerParams.java | 63 +++++++ .../flink/ml/feature/VectorAssemblerTest.java | 178 ++++++++++++++++++++ 4 files changed, 462 insertions(+) diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasOutputCol.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasOutputCol.java new file mode 100644 index 0000000..e191058 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasOutputCol.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.param; + +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.StringParam; +import org.apache.flink.ml.param.WithParams; + +/** Interface for the shared outputCol param. */ +public interface HasOutputCol<T> extends WithParams<T> { + Param<String> OUTPUT_COL = + new StringParam( + "outputCol", "Output column name.", "output", ParamValidators.notNull()); + + default String getOutputCol() { + return get(OUTPUT_COL); + } + + default T setOutputCol(String value) { + return set(OUTPUT_COL, value); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java new file mode 100644 index 0000000..61d84d6 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.vectorassembler; + +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Transformer; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * A feature transformer that combines a given list of input columns into a vector column. Types of + * input columns must be either vector or numerical value. + */ +public class VectorAssembler + implements Transformer<VectorAssembler>, VectorAssemblerParams<VectorAssembler> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private static final double RATIO = 1.5; + + public VectorAssembler() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Table[] transform(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll( + inputTypeInfo.getFieldTypes(), TypeInformation.of(Vector.class)), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol())); + DataStream<Row> output = + tEnv.toDataStream(inputs[0]) + .flatMap( + new AssemblerFunc(getInputCols(), getHandleInvalid()), + outputTypeInfo); + Table outputTable = tEnv.fromDataStream(output); + return new Table[] {outputTable}; + } + + private static class AssemblerFunc implements FlatMapFunction<Row, Row> { + private final String[] inputCols; + private final String handleInvalid; + + public AssemblerFunc(String[] inputCols, String handleInvalid) { + this.inputCols = inputCols; + this.handleInvalid = handleInvalid; + } + + @Override + public void flatMap(Row value, Collector<Row> out) throws Exception { + try { + Object[] objects = new Object[inputCols.length]; + for (int i = 0; i < objects.length; ++i) { + objects[i] = value.getField(inputCols[i]); + } + Vector assembledVector = assemble(objects); + out.collect(Row.join(value, Row.of(assembledVector))); + } catch (Exception e) { + switch (handleInvalid) { + case VectorAssemblerParams.ERROR_INVALID: + throw e; + case VectorAssemblerParams.SKIP_INVALID: + return; + case VectorAssemblerParams.KEEP_INVALID: + out.collect(Row.join(value, Row.of((Object) null))); + return; + default: + throw new UnsupportedOperationException( + "handleInvalid=" + handleInvalid + " is not supported"); + } + } + } + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static VectorAssembler load(StreamTableEnvironment env, String path) throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + private static Vector assemble(Object[] objects) { + int offset = 0; + Map<Integer, Double> map = new LinkedHashMap<>(objects.length); + for (Object object : objects) { + Preconditions.checkNotNull(object, "Input column value should not be null."); + if (object instanceof Number) { + map.put(offset++, ((Number) object).doubleValue()); + } else if (object instanceof Vector) { + offset = appendVector((Vector) object, map, offset); + } else { + throw new IllegalArgumentException("Input type has not been supported yet."); + } + } + + if (map.size() * RATIO > offset) { + DenseVector assembledVector = new DenseVector(offset); + for (int key : map.keySet()) { + assembledVector.values[key] = map.get(key); + } + return assembledVector; + } else { + return convertMapToSparseVector(offset, map); + } + } + + private static int appendVector(Vector vec, Map<Integer, Double> map, int offset) { + if (vec instanceof SparseVector) { + SparseVector sparseVector = (SparseVector) vec; + int[] indices = sparseVector.indices; + double[] values = sparseVector.values; + for (int i = 0; i < indices.length; ++i) { + map.put(offset + indices[i], values[i]); + } + offset += sparseVector.size(); + } else { + DenseVector denseVector = (DenseVector) vec; + for (int i = 0; i < denseVector.size(); ++i) { + map.put(offset++, denseVector.values[i]); + } + } + return offset; + } + + private static SparseVector convertMapToSparseVector(int size, Map<Integer, Double> map) { + int[] indices = new int[map.size()]; + double[] values = new double[map.size()]; + int offset = 0; + for (Map.Entry<Integer, Double> entry : map.entrySet()) { + indices[offset] = entry.getKey(); + values[offset++] = entry.getValue(); + } + return new SparseVector(size, indices, values); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssemblerParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssemblerParams.java new file mode 100644 index 0000000..5e2cda4 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssemblerParams.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.vectorassembler; + +import org.apache.flink.ml.common.param.HasInputCols; +import org.apache.flink.ml.common.param.HasOutputCol; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.StringParam; + +/** + * Params of VectorAssembler. + * + * @param <T> The class type of this instance. + */ +public interface VectorAssemblerParams<T> extends HasInputCols<T>, HasOutputCol<T> { + + String ERROR_INVALID = "error"; + String SKIP_INVALID = "skip"; + String KEEP_INVALID = "keep"; + + /** + * Supported options and the corresponding behavior to handle invalid entries is listed as + * follows. + * + * <ul> + * <li>error: raise an exception. + * <li>skip: filter out rows with bad values. + * <li>keep: output bad rows with output column's value set to null. + * </ul> + */ + Param<String> HANDLE_INVALID = + new StringParam( + "handleInvalid", + "Strategy to handle invalid entries.", + ERROR_INVALID, + ParamValidators.inArray(ERROR_INVALID, SKIP_INVALID, KEEP_INVALID)); + + default String getHandleInvalid() { + return get(HANDLE_INVALID); + } + + default T setHandleInvalid(String value) { + set(HANDLE_INVALID, value); + return (T) this; + } +} diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorAssemblerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorAssemblerTest.java new file mode 100644 index 0000000..193077c --- /dev/null +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorAssemblerTest.java @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.common.param.HasHandleInvalid; +import org.apache.flink.ml.feature.vectorassembler.VectorAssembler; +import org.apache.flink.ml.feature.vectorassembler.VectorAssemblerParams; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.util.StageTestUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.test.util.AbstractTestBase; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.util.Arrays; +import java.util.List; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +/** Tests VectorAssembler. */ +public class VectorAssemblerTest extends AbstractTestBase { + + private StreamTableEnvironment tEnv; + private Table inputDataTable; + + private static final List<Row> INPUT_DATA = + Arrays.asList( + Row.of( + 0, + Vectors.dense(2.1, 3.1), + 1.0, + Vectors.sparse(5, new int[] {3}, new double[] {1.0})), + Row.of( + 1, + Vectors.dense(2.1, 3.1), + 1.0, + Vectors.sparse( + 5, new int[] {4, 2, 3, 1}, new double[] {4.0, 2.0, 3.0, 1.0})), + Row.of(2, null, null, null)); + + private static final SparseVector EXPECTED_OUTPUT_DATA_1 = + Vectors.sparse(8, new int[] {0, 1, 2, 6}, new double[] {2.1, 3.1, 1.0, 1.0}); + private static final DenseVector EXPECTED_OUTPUT_DATA_2 = + Vectors.dense(2.1, 3.1, 1.0, 0.0, 1.0, 2.0, 3.0, 4.0); + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(4); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + DataStream<Row> dataStream = env.fromCollection(INPUT_DATA); + inputDataTable = tEnv.fromDataStream(dataStream).as("id", "vec", "num", "sparseVec"); + } + + private void verifyOutputResult(Table output, String outputCol, int outputSize) + throws Exception { + DataStream<Row> dataStream = tEnv.toDataStream(output); + List<Row> results = IteratorUtils.toList(dataStream.executeAndCollect()); + assertEquals(outputSize, results.size()); + for (Row result : results) { + if (result.getField(0) == (Object) 0) { + assertEquals(EXPECTED_OUTPUT_DATA_1, result.getField(outputCol)); + } else if (result.getField(0) == (Object) 1) { + assertEquals(EXPECTED_OUTPUT_DATA_2, result.getField(outputCol)); + } else { + assertNull(result.getField(outputCol)); + } + } + } + + @Test + public void testParam() { + VectorAssembler vectorAssembler = new VectorAssembler(); + assertEquals(HasHandleInvalid.ERROR_INVALID, vectorAssembler.getHandleInvalid()); + assertEquals("output", vectorAssembler.getOutputCol()); + vectorAssembler + .setInputCols("vec", "num", "sparseVec") + .setOutputCol("assembledVec") + .setHandleInvalid(HasHandleInvalid.SKIP_INVALID); + assertArrayEquals(new String[] {"vec", "num", "sparseVec"}, vectorAssembler.getInputCols()); + assertEquals(HasHandleInvalid.SKIP_INVALID, vectorAssembler.getHandleInvalid()); + assertEquals("assembledVec", vectorAssembler.getOutputCol()); + } + + @Test + public void testKeepInvalid() throws Exception { + VectorAssembler vectorAssembler = + new VectorAssembler() + .setInputCols("vec", "num", "sparseVec") + .setOutputCol("assembledVec") + .setHandleInvalid(VectorAssemblerParams.KEEP_INVALID); + Table output = vectorAssembler.transform(inputDataTable)[0]; + assertEquals( + Arrays.asList("id", "vec", "num", "sparseVec", "assembledVec"), + output.getResolvedSchema().getColumnNames()); + verifyOutputResult(output, vectorAssembler.getOutputCol(), 3); + } + + @Test + public void testErrorInvalid() { + VectorAssembler vectorAssembler = + new VectorAssembler() + .setInputCols("vec", "num", "sparseVec") + .setOutputCol("assembledVec") + .setHandleInvalid(HasHandleInvalid.ERROR_INVALID); + try { + Table outputTable = vectorAssembler.transform(inputDataTable)[0]; + outputTable.execute().collect().next(); + Assert.fail("Expected IllegalArgumentException"); + } catch (Exception e) { + assertEquals( + "Input column value should not be null.", + e.getCause().getCause().getCause().getCause().getCause().getMessage()); + } + } + + @Test + public void testSkipInvalid() throws Exception { + VectorAssembler vectorAssembler = + new VectorAssembler() + .setInputCols("vec", "num", "sparseVec") + .setOutputCol("assembledVec") + .setHandleInvalid(VectorAssemblerParams.SKIP_INVALID); + Table output = vectorAssembler.transform(inputDataTable)[0]; + assertEquals( + Arrays.asList("id", "vec", "num", "sparseVec", "assembledVec"), + output.getResolvedSchema().getColumnNames()); + verifyOutputResult(output, vectorAssembler.getOutputCol(), 2); + } + + @Test + public void testSaveLoadAndTransform() throws Exception { + VectorAssembler vectorAssembler = + new VectorAssembler() + .setInputCols("vec", "num", "sparseVec") + .setOutputCol("assembledVec") + .setHandleInvalid(HasHandleInvalid.SKIP_INVALID); + VectorAssembler loadedVectorAssembler = + StageTestUtils.saveAndReload( + tEnv, vectorAssembler, TEMPORARY_FOLDER.newFolder().getAbsolutePath()); + Table output = loadedVectorAssembler.transform(inputDataTable)[0]; + verifyOutputResult(output, loadedVectorAssembler.getOutputCol(), 2); + } +}