This is an automated email from the ASF dual-hosted git repository.

lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new 0769392  [FLINK-25616] Add Transformer for VectorAssembler
0769392 is described below

commit 076939285ffb7e5167371ae4433a1b4f394e6753
Author: weibo <weibo....@alibaba-inc.com>
AuthorDate: Sat Apr 2 15:32:36 2022 +0800

    [FLINK-25616] Add Transformer for VectorAssembler
    
    This closes #56.
---
 .../apache/flink/ml/common/param/HasOutputCol.java |  39 +++++
 .../feature/vectorassembler/VectorAssembler.java   | 182 +++++++++++++++++++++
 .../vectorassembler/VectorAssemblerParams.java     |  63 +++++++
 .../flink/ml/feature/VectorAssemblerTest.java      | 178 ++++++++++++++++++++
 4 files changed, 462 insertions(+)

diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasOutputCol.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasOutputCol.java
new file mode 100644
index 0000000..e191058
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasOutputCol.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+import org.apache.flink.ml.param.WithParams;
+
+/**
+ * Interface for the shared outputCol param.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface HasOutputCol<T> extends WithParams<T> {
+    /** Name of the output column. The default value is "output". */
+    Param<String> OUTPUT_COL =
+            new StringParam(
+                    "outputCol",
+                    "Output column name.",
+                    "output",
+                    ParamValidators.notNull());
+
+    /** Returns the current value of {@link #OUTPUT_COL}. */
+    default String getOutputCol() {
+        return get(OUTPUT_COL);
+    }
+
+    /**
+     * Sets {@link #OUTPUT_COL}.
+     *
+     * @param value The new output column name.
+     * @return The value returned by the underlying {@code set} call.
+     */
+    default T setOutputCol(String value) {
+        return set(OUTPUT_COL, value);
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java
new file mode 100644
index 0000000..61d84d6
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.vectorassembler;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Transformer;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.Map;
+
+/**
+ * A feature transformer that combines a given list of input columns into a 
vector column. Types of
+ * input columns must be either vector or numerical value.
+ */
+public class VectorAssembler
+        implements Transformer<VectorAssembler>, 
VectorAssemblerParams<VectorAssembler> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final double RATIO = 1.5;
+
+    public VectorAssembler() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        RowTypeInfo inputTypeInfo = 
TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(), 
TypeInformation.of(Vector.class)),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), 
getOutputCol()));
+        DataStream<Row> output =
+                tEnv.toDataStream(inputs[0])
+                        .flatMap(
+                                new AssemblerFunc(getInputCols(), 
getHandleInvalid()),
+                                outputTypeInfo);
+        Table outputTable = tEnv.fromDataStream(output);
+        return new Table[] {outputTable};
+    }
+
+    private static class AssemblerFunc implements FlatMapFunction<Row, Row> {
+        private final String[] inputCols;
+        private final String handleInvalid;
+
+        public AssemblerFunc(String[] inputCols, String handleInvalid) {
+            this.inputCols = inputCols;
+            this.handleInvalid = handleInvalid;
+        }
+
+        @Override
+        public void flatMap(Row value, Collector<Row> out) throws Exception {
+            try {
+                Object[] objects = new Object[inputCols.length];
+                for (int i = 0; i < objects.length; ++i) {
+                    objects[i] = value.getField(inputCols[i]);
+                }
+                Vector assembledVector = assemble(objects);
+                out.collect(Row.join(value, Row.of(assembledVector)));
+            } catch (Exception e) {
+                switch (handleInvalid) {
+                    case VectorAssemblerParams.ERROR_INVALID:
+                        throw e;
+                    case VectorAssemblerParams.SKIP_INVALID:
+                        return;
+                    case VectorAssemblerParams.KEEP_INVALID:
+                        out.collect(Row.join(value, Row.of((Object) null)));
+                        return;
+                    default:
+                        throw new UnsupportedOperationException(
+                                "handleInvalid=" + handleInvalid + " is not 
supported");
+                }
+            }
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static VectorAssembler load(StreamTableEnvironment env, String 
path) throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static Vector assemble(Object[] objects) {
+        int offset = 0;
+        Map<Integer, Double> map = new LinkedHashMap<>(objects.length);
+        for (Object object : objects) {
+            Preconditions.checkNotNull(object, "Input column value should not 
be null.");
+            if (object instanceof Number) {
+                map.put(offset++, ((Number) object).doubleValue());
+            } else if (object instanceof Vector) {
+                offset = appendVector((Vector) object, map, offset);
+            } else {
+                throw new IllegalArgumentException("Input type has not been 
supported yet.");
+            }
+        }
+
+        if (map.size() * RATIO > offset) {
+            DenseVector assembledVector = new DenseVector(offset);
+            for (int key : map.keySet()) {
+                assembledVector.values[key] = map.get(key);
+            }
+            return assembledVector;
+        } else {
+            return convertMapToSparseVector(offset, map);
+        }
+    }
+
+    private static int appendVector(Vector vec, Map<Integer, Double> map, int 
offset) {
+        if (vec instanceof SparseVector) {
+            SparseVector sparseVector = (SparseVector) vec;
+            int[] indices = sparseVector.indices;
+            double[] values = sparseVector.values;
+            for (int i = 0; i < indices.length; ++i) {
+                map.put(offset + indices[i], values[i]);
+            }
+            offset += sparseVector.size();
+        } else {
+            DenseVector denseVector = (DenseVector) vec;
+            for (int i = 0; i < denseVector.size(); ++i) {
+                map.put(offset++, denseVector.values[i]);
+            }
+        }
+        return offset;
+    }
+
+    private static SparseVector convertMapToSparseVector(int size, 
Map<Integer, Double> map) {
+        int[] indices = new int[map.size()];
+        double[] values = new double[map.size()];
+        int offset = 0;
+        for (Map.Entry<Integer, Double> entry : map.entrySet()) {
+            indices[offset] = entry.getKey();
+            values[offset++] = entry.getValue();
+        }
+        return new SparseVector(size, indices, values);
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssemblerParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssemblerParams.java
new file mode 100644
index 0000000..5e2cda4
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssemblerParams.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.vectorassembler;
+
+import org.apache.flink.ml.common.param.HasInputCols;
+import org.apache.flink.ml.common.param.HasOutputCol;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+
+/**
+ * Params of VectorAssembler.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface VectorAssemblerParams<T> extends HasInputCols<T>, HasOutputCol<T> {
+
+    /** Value of {@link #HANDLE_INVALID} that raises an exception on invalid entries. */
+    String ERROR_INVALID = "error";
+
+    /** Value of {@link #HANDLE_INVALID} that filters out rows with bad values. */
+    String SKIP_INVALID = "skip";
+
+    /** Value of {@link #HANDLE_INVALID} that keeps bad rows with a null output value. */
+    String KEEP_INVALID = "keep";
+
+    /**
+     * Supported options and the corresponding behavior to handle invalid entries is listed as
+     * follows.
+     *
+     * <ul>
+     *   <li>error: raise an exception.
+     *   <li>skip: filter out rows with bad values.
+     *   <li>keep: output bad rows with output column's value set to null.
+     * </ul>
+     */
+    Param<String> HANDLE_INVALID =
+            new StringParam(
+                    "handleInvalid",
+                    "Strategy to handle invalid entries.",
+                    ERROR_INVALID,
+                    ParamValidators.inArray(ERROR_INVALID, SKIP_INVALID, KEEP_INVALID));
+
+    /** Returns the current value of {@link #HANDLE_INVALID}. */
+    default String getHandleInvalid() {
+        return get(HANDLE_INVALID);
+    }
+
+    /**
+     * Sets {@link #HANDLE_INVALID}.
+     *
+     * @param value One of {@link #ERROR_INVALID}, {@link #SKIP_INVALID} or {@link #KEEP_INVALID}.
+     * @return The value returned by the underlying {@code set} call.
+     */
+    default T setHandleInvalid(String value) {
+        // Return the result of set(...) directly, as setOutputCol does, instead of an unchecked
+        // cast of this.
+        return set(HANDLE_INVALID, value);
+    }
+}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorAssemblerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorAssemblerTest.java
new file mode 100644
index 0000000..193077c
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorAssemblerTest.java
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.common.param.HasHandleInvalid;
+import org.apache.flink.ml.feature.vectorassembler.VectorAssembler;
+import org.apache.flink.ml.feature.vectorassembler.VectorAssemblerParams;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+/** Tests VectorAssembler. */
+public class VectorAssemblerTest extends AbstractTestBase {
+
+    private StreamTableEnvironment tEnv;
+    private Table inputDataTable;
+
+    // Rows with id 0 and 1 are valid; the row with id 2 only holds null input values.
+    private static final List<Row> INPUT_DATA =
+            Arrays.asList(
+                    Row.of(
+                            0,
+                            Vectors.dense(2.1, 3.1),
+                            1.0,
+                            Vectors.sparse(5, new int[] {3}, new double[] {1.0})),
+                    Row.of(
+                            1,
+                            Vectors.dense(2.1, 3.1),
+                            1.0,
+                            Vectors.sparse(
+                                    5, new int[] {4, 2, 3, 1}, new double[] {4.0, 2.0, 3.0, 1.0})),
+                    Row.of(2, null, null, null));
+
+    // Expected assembled vector of the row with id 0.
+    private static final SparseVector EXPECTED_OUTPUT_DATA_1 =
+            Vectors.sparse(8, new int[] {0, 1, 2, 6}, new double[] {2.1, 3.1, 1.0, 1.0});
+    // Expected assembled vector of the row with id 1.
+    private static final DenseVector EXPECTED_OUTPUT_DATA_2 =
+            Vectors.dense(2.1, 3.1, 1.0, 0.0, 1.0, 2.0, 3.0, 4.0);
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        DataStream<Row> dataStream = env.fromCollection(INPUT_DATA);
+        inputDataTable = tEnv.fromDataStream(dataStream).as("id", "vec", "num", "sparseVec");
+    }
+
+    /**
+     * Executes the given table and checks the number of rows as well as the assembled vector of
+     * every row. Rows with an id other than 0 or 1 are expected to hold a null output value.
+     */
+    private void verifyOutputResult(Table output, String outputCol, int outputSize)
+            throws Exception {
+        DataStream<Row> dataStream = tEnv.toDataStream(output);
+        List<Row> results = IteratorUtils.toList(dataStream.executeAndCollect());
+        assertEquals(outputSize, results.size());
+        for (Row result : results) {
+            // Compare by value: reference equality on boxed integers only works by accident of
+            // the Integer cache.
+            Object id = result.getField(0);
+            if (Integer.valueOf(0).equals(id)) {
+                assertEquals(EXPECTED_OUTPUT_DATA_1, result.getField(outputCol));
+            } else if (Integer.valueOf(1).equals(id)) {
+                assertEquals(EXPECTED_OUTPUT_DATA_2, result.getField(outputCol));
+            } else {
+                assertNull(result.getField(outputCol));
+            }
+        }
+    }
+
+    @Test
+    public void testParam() {
+        VectorAssembler vectorAssembler = new VectorAssembler();
+        assertEquals(HasHandleInvalid.ERROR_INVALID, vectorAssembler.getHandleInvalid());
+        assertEquals("output", vectorAssembler.getOutputCol());
+        vectorAssembler
+                .setInputCols("vec", "num", "sparseVec")
+                .setOutputCol("assembledVec")
+                .setHandleInvalid(HasHandleInvalid.SKIP_INVALID);
+        assertArrayEquals(new String[] {"vec", "num", "sparseVec"}, vectorAssembler.getInputCols());
+        assertEquals(HasHandleInvalid.SKIP_INVALID, vectorAssembler.getHandleInvalid());
+        assertEquals("assembledVec", vectorAssembler.getOutputCol());
+    }
+
+    @Test
+    public void testKeepInvalid() throws Exception {
+        VectorAssembler vectorAssembler =
+                new VectorAssembler()
+                        .setInputCols("vec", "num", "sparseVec")
+                        .setOutputCol("assembledVec")
+                        .setHandleInvalid(VectorAssemblerParams.KEEP_INVALID);
+        Table output = vectorAssembler.transform(inputDataTable)[0];
+        assertEquals(
+                Arrays.asList("id", "vec", "num", "sparseVec", "assembledVec"),
+                output.getResolvedSchema().getColumnNames());
+        verifyOutputResult(output, vectorAssembler.getOutputCol(), 3);
+    }
+
+    @Test
+    public void testErrorInvalid() {
+        VectorAssembler vectorAssembler =
+                new VectorAssembler()
+                        .setInputCols("vec", "num", "sparseVec")
+                        .setOutputCol("assembledVec")
+                        .setHandleInvalid(HasHandleInvalid.ERROR_INVALID);
+        try {
+            Table outputTable = vectorAssembler.transform(inputDataTable)[0];
+            outputTable.execute().collect().next();
+            Assert.fail("Expected an exception caused by the null input values.");
+        } catch (Exception e) {
+            // Walk to the root cause instead of hard-coding how many wrappers the runtime adds.
+            Throwable rootCause = e;
+            while (rootCause.getCause() != null) {
+                rootCause = rootCause.getCause();
+            }
+            assertEquals("Input column value should not be null.", rootCause.getMessage());
+        }
+    }
+
+    @Test
+    public void testSkipInvalid() throws Exception {
+        VectorAssembler vectorAssembler =
+                new VectorAssembler()
+                        .setInputCols("vec", "num", "sparseVec")
+                        .setOutputCol("assembledVec")
+                        .setHandleInvalid(VectorAssemblerParams.SKIP_INVALID);
+        Table output = vectorAssembler.transform(inputDataTable)[0];
+        assertEquals(
+                Arrays.asList("id", "vec", "num", "sparseVec", "assembledVec"),
+                output.getResolvedSchema().getColumnNames());
+        verifyOutputResult(output, vectorAssembler.getOutputCol(), 2);
+    }
+
+    @Test
+    public void testSaveLoadAndTransform() throws Exception {
+        VectorAssembler vectorAssembler =
+                new VectorAssembler()
+                        .setInputCols("vec", "num", "sparseVec")
+                        .setOutputCol("assembledVec")
+                        .setHandleInvalid(HasHandleInvalid.SKIP_INVALID);
+        VectorAssembler loadedVectorAssembler =
+                StageTestUtils.saveAndReload(
+                        tEnv, vectorAssembler, TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+        Table output = loadedVectorAssembler.transform(inputDataTable)[0];
+        verifyOutputResult(output, loadedVectorAssembler.getOutputCol(), 2);
+    }
+}

Reply via email to