lindong28 commented on a change in pull request #56: URL: https://github.com/apache/flink-ml/pull/56#discussion_r840443219
########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java ########## @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.vectorassembler; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Transformer; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.common.param.HasHandleInvalid; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import 
java.io.IOException; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * Vector assembler is a transformer that combines a given list of columns into a single vector Review comment: Other Transformer subclasses' Java docs typically start with `e.g. A model...`. And the sentence `It will combine raw features and features generated` seems a little bit redundant. Would the following Java doc be simpler? ``` /** * A feature transformer that combines a given list of input columns into a vector column. Types of * input columns must be either vector or numerical value. */ ``` ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java ########## @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.feature.vectorassembler; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Transformer; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.common.param.HasHandleInvalid; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * Vector assembler is a transformer that combines a given list of columns into a single vector + * column. It will combine raw features and features generated by different feature transformers + * into a single feature vector. The input features of this transformer must be a vector feature or + * a numerical feature. + */ +public class VectorAssembler + implements Transformer<VectorAssembler>, VectorAssemblerParams<VectorAssembler> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private static final double RATIO = 1.5; + + public VectorAssembler() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Table[] transform(Table... 
inputs) { + Preconditions.checkArgument(inputs.length == 1); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll( + inputTypeInfo.getFieldTypes(), TypeInformation.of(Vector.class)), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol())); + DataStream<Row> output = + tEnv.toDataStream(inputs[0]) + .map(new AssemblerFunc(getInputCols(), getHandleInvalid()), outputTypeInfo); + Table outputTable = tEnv.fromDataStream(output); + return new Table[] {outputTable}; + } + + private static class AssemblerFunc implements MapFunction<Row, Row> { + private final String[] inputCols; + private final String handleInvalid; + + public AssemblerFunc(String[] inputCols, String handleInvalid) { + this.inputCols = inputCols; + this.handleInvalid = handleInvalid; + } + + @Override + public Row map(Row value) { + Object[] objects = new Object[inputCols.length]; + for (int i = 0; i < objects.length; ++i) { + objects[i] = value.getField(inputCols[i]); + } + return Row.join(value, Row.of(assemble(objects, handleInvalid))); + } + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static VectorAssembler load(StreamTableEnvironment env, String path) throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + private static Vector assemble(Object[] objects, String handleInvalid) { + int offset = 0; + Map<Integer, Double> map = new LinkedHashMap<>(objects.length); + for (Object object : objects) { + try { + Preconditions.checkNotNull(object, "VectorAssembler Error: input data is null."); + if (object instanceof Number) { + map.put(offset++, ((Number) object).doubleValue()); + } else if (object 
instanceof Vector) { + offset = appendVector((Vector) object, map, offset); + } else { + throw new UnsupportedOperationException( + "Vector assembler : input type has not been supported yet."); + } + } catch (Exception e) { + switch (handleInvalid) { + case HasHandleInvalid.ERROR_INVALID: + throw new RuntimeException("Vector assembler failed.", e); + case HasHandleInvalid.SKIP_INVALID: + return null; + default: + } + } + } + + if (map.size() * RATIO > offset) { + DenseVector assembledVector = new DenseVector(offset); + for (int key : map.keySet()) { + assembledVector.values[key] = map.get(key); + } + return assembledVector; + } else { + return convertMapToSparseVector(offset, map); + } + } + + private static int appendVector(Vector vec, Map<Integer, Double> map, int offset) { + if (vec instanceof SparseVector) { + SparseVector sparseVector = (SparseVector) vec; + int[] indices = sparseVector.indices; + double[] values = sparseVector.values; + for (int j = 0; j < indices.length; ++j) { + map.put(offset + indices[j], values[j]); + } + offset += sparseVector.size(); + } else if (vec instanceof DenseVector) { Review comment: Would it be simpler to just do `else {...}`? ########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorAssemblerTest.java ########## @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.common.param.HasHandleInvalid; +import org.apache.flink.ml.feature.vectorassembler.VectorAssembler; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.util.StageTestUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.test.util.AbstractTestBase; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.util.Arrays; +import java.util.List; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +/** Tests VectorAssembler. 
*/ +public class VectorAssemblerTest extends AbstractTestBase { + + private StreamTableEnvironment tEnv; + private Table inputDataTable; + + private static final List<Row> INPUT_DATA = + Arrays.asList( + Row.of( + 0, + Vectors.dense(2.1, 3.1), + 1.0, + Vectors.sparse(5, new int[] {3}, new double[] {1.0})), + Row.of( + 1, + Vectors.dense(2.1, 3.1), + 1.0, + Vectors.sparse( + 5, new int[] {4, 2, 3, 1}, new double[] {4.0, 2.0, 3.0, 1.0})), + Row.of(2, null, 1.0, null)); + + private static final SparseVector EXPECTED_DATA_1 = Review comment: nits: To be more consistent with `INPUT_DATA`, how about renaming it as `EXPECTED_OUTPUT_DATA_1`? ########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorAssemblerTest.java ########## @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.feature; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.common.param.HasHandleInvalid; +import org.apache.flink.ml.feature.vectorassembler.VectorAssembler; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.util.StageTestUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.test.util.AbstractTestBase; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.util.Arrays; +import java.util.List; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +/** Tests VectorAssembler. 
*/ +public class VectorAssemblerTest extends AbstractTestBase { + + private StreamTableEnvironment tEnv; + private Table inputDataTable; + + private static final List<Row> INPUT_DATA = + Arrays.asList( + Row.of( + 0, + Vectors.dense(2.1, 3.1), + 1.0, + Vectors.sparse(5, new int[] {3}, new double[] {1.0})), + Row.of( + 1, + Vectors.dense(2.1, 3.1), + 1.0, + Vectors.sparse( + 5, new int[] {4, 2, 3, 1}, new double[] {4.0, 2.0, 3.0, 1.0})), + Row.of(2, null, 1.0, null)); + + private static final SparseVector EXPECTED_DATA_1 = + Vectors.sparse(8, new int[] {0, 1, 2, 6}, new double[] {2.1, 3.1, 1.0, 1.0}); + private static final DenseVector EXPECTED_DATA_2 = + Vectors.dense(2.1, 3.1, 1.0, 0.0, 1.0, 2.0, 3.0, 4.0); + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(4); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + Schema schema = + Schema.newBuilder() + .column("f0", DataTypes.INT()) + .column("f1", DataTypes.of(DenseVector.class)) + .column("f2", DataTypes.DOUBLE()) + .column("f3", DataTypes.of(SparseVector.class)) + .build(); + DataStream<Row> dataStream = env.fromCollection(INPUT_DATA); + inputDataTable = Review comment: We can just do the following without using `schema`. ``` inputDataTable = tEnv.fromDataStream(dataStream).as("id", "vec", "num", "sparseVec") ``` ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java ########## @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.vectorassembler; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Transformer; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.common.param.HasHandleInvalid; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * Vector assembler is a transformer that combines a given list of columns into a single vector + * column. It will combine raw features and features generated by different feature transformers + * into a single feature vector. 
The input features of this transformer must be a vector feature or + * a numerical feature. + */ +public class VectorAssembler + implements Transformer<VectorAssembler>, VectorAssemblerParams<VectorAssembler> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private static final double RATIO = 1.5; + + public VectorAssembler() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Table[] transform(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll( + inputTypeInfo.getFieldTypes(), TypeInformation.of(Vector.class)), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol())); + DataStream<Row> output = + tEnv.toDataStream(inputs[0]) + .map(new AssemblerFunc(getInputCols(), getHandleInvalid()), outputTypeInfo); + Table outputTable = tEnv.fromDataStream(output); + return new Table[] {outputTable}; + } + + private static class AssemblerFunc implements MapFunction<Row, Row> { + private final String[] inputCols; + private final String handleInvalid; + + public AssemblerFunc(String[] inputCols, String handleInvalid) { + this.inputCols = inputCols; + this.handleInvalid = handleInvalid; + } + + @Override + public Row map(Row value) { + Object[] objects = new Object[inputCols.length]; + for (int i = 0; i < objects.length; ++i) { + objects[i] = value.getField(inputCols[i]); + } + return Row.join(value, Row.of(assemble(objects, handleInvalid))); + } + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static VectorAssembler load(StreamTableEnvironment env, String path) throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + @Override + 
public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + private static Vector assemble(Object[] objects, String handleInvalid) { + int offset = 0; + Map<Integer, Double> map = new LinkedHashMap<>(objects.length); + for (Object object : objects) { + try { + Preconditions.checkNotNull(object, "VectorAssembler Error: input data is null."); + if (object instanceof Number) { + map.put(offset++, ((Number) object).doubleValue()); + } else if (object instanceof Vector) { + offset = appendVector((Vector) object, map, offset); + } else { + throw new UnsupportedOperationException( + "Vector assembler : input type has not been supported yet."); + } + } catch (Exception e) { + switch (handleInvalid) { + case HasHandleInvalid.ERROR_INVALID: + throw new RuntimeException("Vector assembler failed.", e); + case HasHandleInvalid.SKIP_INVALID: + return null; Review comment: In this case, the returned table will contain a row whose outputColumn value is null. This sounds like `keep` instead of `skip`. Spark uses `keep` as the skip_invalid param value in this case. Should we add a `keep` option in `HasHandleInvalid`? ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java ########## @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.vectorassembler; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Transformer; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.common.param.HasHandleInvalid; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * Vector assembler is a transformer that combines a given list of columns into a single vector + * column. It will combine raw features and features generated by different feature transformers + * into a single feature vector. The input features of this transformer must be a vector feature or + * a numerical feature. + */ +public class VectorAssembler + implements Transformer<VectorAssembler>, VectorAssemblerParams<VectorAssembler> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private static final double RATIO = 1.5; + + public VectorAssembler() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Table[] transform(Table... 
inputs) { + Preconditions.checkArgument(inputs.length == 1); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll( + inputTypeInfo.getFieldTypes(), TypeInformation.of(Vector.class)), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol())); + DataStream<Row> output = + tEnv.toDataStream(inputs[0]) + .map(new AssemblerFunc(getInputCols(), getHandleInvalid()), outputTypeInfo); + Table outputTable = tEnv.fromDataStream(output); + return new Table[] {outputTable}; + } + + private static class AssemblerFunc implements MapFunction<Row, Row> { + private final String[] inputCols; + private final String handleInvalid; + + public AssemblerFunc(String[] inputCols, String handleInvalid) { + this.inputCols = inputCols; + this.handleInvalid = handleInvalid; + } + + @Override + public Row map(Row value) { + Object[] objects = new Object[inputCols.length]; + for (int i = 0; i < objects.length; ++i) { + objects[i] = value.getField(inputCols[i]); + } + return Row.join(value, Row.of(assemble(objects, handleInvalid))); + } + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static VectorAssembler load(StreamTableEnvironment env, String path) throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + private static Vector assemble(Object[] objects, String handleInvalid) { + int offset = 0; + Map<Integer, Double> map = new LinkedHashMap<>(objects.length); + for (Object object : objects) { + try { + Preconditions.checkNotNull(object, "VectorAssembler Error: input data is null."); + if (object instanceof Number) { + map.put(offset++, ((Number) object).doubleValue()); + } else if (object 
instanceof Vector) { + offset = appendVector((Vector) object, map, offset); + } else { + throw new UnsupportedOperationException( Review comment: Should it be `IllegalArgumentException`? And we probably don't need to explicitly specify `Vector assembler` in the error message since the stack trace should already contain this information. ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java ########## @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.feature.vectorassembler; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Transformer; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.common.param.HasHandleInvalid; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * Vector assembler is a transformer that combines a given list of columns into a single vector + * column. It will combine raw features and features generated by different feature transformers + * into a single feature vector. The input features of this transformer must be a vector feature or + * a numerical feature. + */ +public class VectorAssembler + implements Transformer<VectorAssembler>, VectorAssemblerParams<VectorAssembler> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private static final double RATIO = 1.5; + + public VectorAssembler() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Table[] transform(Table... 
inputs) { + Preconditions.checkArgument(inputs.length == 1); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll( + inputTypeInfo.getFieldTypes(), TypeInformation.of(Vector.class)), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol())); + DataStream<Row> output = + tEnv.toDataStream(inputs[0]) + .map(new AssemblerFunc(getInputCols(), getHandleInvalid()), outputTypeInfo); + Table outputTable = tEnv.fromDataStream(output); + return new Table[] {outputTable}; + } + + private static class AssemblerFunc implements MapFunction<Row, Row> { + private final String[] inputCols; + private final String handleInvalid; + + public AssemblerFunc(String[] inputCols, String handleInvalid) { + this.inputCols = inputCols; + this.handleInvalid = handleInvalid; + } + + @Override + public Row map(Row value) { + Object[] objects = new Object[inputCols.length]; + for (int i = 0; i < objects.length; ++i) { + objects[i] = value.getField(inputCols[i]); + } + return Row.join(value, Row.of(assemble(objects, handleInvalid))); + } + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static VectorAssembler load(StreamTableEnvironment env, String path) throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + private static Vector assemble(Object[] objects, String handleInvalid) { + int offset = 0; + Map<Integer, Double> map = new LinkedHashMap<>(objects.length); + for (Object object : objects) { + try { + Preconditions.checkNotNull(object, "VectorAssembler Error: input data is null."); + if (object instanceof Number) { + map.put(offset++, ((Number) object).doubleValue()); + } else if (object 
instanceof Vector) { + offset = appendVector((Vector) object, map, offset); + } else { + throw new UnsupportedOperationException( + "Vector assembler : input type has not been supported yet."); + } + } catch (Exception e) { + switch (handleInvalid) { + case HasHandleInvalid.ERROR_INVALID: + throw new RuntimeException("Vector assembler failed.", e); + case HasHandleInvalid.SKIP_INVALID: + return null; + default: + } + } + } + + if (map.size() * RATIO > offset) { + DenseVector assembledVector = new DenseVector(offset); + for (int key : map.keySet()) { + assembledVector.values[key] = map.get(key); + } + return assembledVector; + } else { + return convertMapToSparseVector(offset, map); + } + } + + private static int appendVector(Vector vec, Map<Integer, Double> map, int offset) { + if (vec instanceof SparseVector) { + SparseVector sparseVector = (SparseVector) vec; + int[] indices = sparseVector.indices; + double[] values = sparseVector.values; + for (int j = 0; j < indices.length; ++j) { + map.put(offset + indices[j], values[j]); + } + offset += sparseVector.size(); + } else if (vec instanceof DenseVector) { + DenseVector denseVector = (DenseVector) vec; + for (int j = 0; j < denseVector.size(); ++j) { + map.put(offset++, denseVector.values[j]); + } + } + return offset; + } + + private static SparseVector convertMapToSparseVector(int size, Map<Integer, Double> map) { + int nnz = map.size(); Review comment: It is not obvious what `nnz` means. Would it be more self-explanatory to rename it as `map_size`? It could be even simpler to do the following: ``` int[] indices = new int[map.size()]; double[] values = new double[map.size()]; ``` ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java ########## @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.vectorassembler; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Transformer; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.common.param.HasHandleInvalid; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * Vector assembler is a transformer that combines a given list of columns into a single vector + * column. 
It will combine raw features and features generated by different feature transformers + * into a single feature vector. The input features of this transformer must be a vector feature or + * a numerical feature. + */ +public class VectorAssembler + implements Transformer<VectorAssembler>, VectorAssemblerParams<VectorAssembler> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private static final double RATIO = 1.5; + + public VectorAssembler() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Table[] transform(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll( + inputTypeInfo.getFieldTypes(), TypeInformation.of(Vector.class)), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol())); + DataStream<Row> output = + tEnv.toDataStream(inputs[0]) + .map(new AssemblerFunc(getInputCols(), getHandleInvalid()), outputTypeInfo); + Table outputTable = tEnv.fromDataStream(output); + return new Table[] {outputTable}; + } + + private static class AssemblerFunc implements MapFunction<Row, Row> { + private final String[] inputCols; + private final String handleInvalid; + + public AssemblerFunc(String[] inputCols, String handleInvalid) { + this.inputCols = inputCols; + this.handleInvalid = handleInvalid; + } + + @Override + public Row map(Row value) { + Object[] objects = new Object[inputCols.length]; + for (int i = 0; i < objects.length; ++i) { + objects[i] = value.getField(inputCols[i]); + } + return Row.join(value, Row.of(assemble(objects, handleInvalid))); + } + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static VectorAssembler 
load(StreamTableEnvironment env, String path) throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + private static Vector assemble(Object[] objects, String handleInvalid) { + int offset = 0; + Map<Integer, Double> map = new LinkedHashMap<>(objects.length); + for (Object object : objects) { + try { + Preconditions.checkNotNull(object, "VectorAssembler Error: input data is null."); + if (object instanceof Number) { + map.put(offset++, ((Number) object).doubleValue()); + } else if (object instanceof Vector) { + offset = appendVector((Vector) object, map, offset); + } else { + throw new UnsupportedOperationException( + "Vector assembler : input type has not been supported yet."); + } + } catch (Exception e) { + switch (handleInvalid) { + case HasHandleInvalid.ERROR_INVALID: + throw new RuntimeException("Vector assembler failed.", e); + case HasHandleInvalid.SKIP_INVALID: + return null; + default: + } + } + } + + if (map.size() * RATIO > offset) { + DenseVector assembledVector = new DenseVector(offset); + for (int key : map.keySet()) { + assembledVector.values[key] = map.get(key); + } + return assembledVector; + } else { + return convertMapToSparseVector(offset, map); + } + } + + private static int appendVector(Vector vec, Map<Integer, Double> map, int offset) { + if (vec instanceof SparseVector) { + SparseVector sparseVector = (SparseVector) vec; + int[] indices = sparseVector.indices; + double[] values = sparseVector.values; + for (int j = 0; j < indices.length; ++j) { + map.put(offset + indices[j], values[j]); + } + offset += sparseVector.size(); + } else if (vec instanceof DenseVector) { + DenseVector denseVector = (DenseVector) vec; + for (int j = 0; j < denseVector.size(); ++j) { Review comment: nits: Could we use `for (int i = 0; i < indices.length; i++)` to be consistent with other for loop code style? 
########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorAssemblerTest.java ########## @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.common.param.HasHandleInvalid; +import org.apache.flink.ml.feature.vectorassembler.VectorAssembler; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.util.StageTestUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.test.util.AbstractTestBase; +import org.apache.flink.types.Row; + +import 
org.apache.commons.collections.IteratorUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.util.Arrays; +import java.util.List; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +/** Tests VectorAssembler. */ +public class VectorAssemblerTest extends AbstractTestBase { + + private StreamTableEnvironment tEnv; + private Table inputDataTable; + + private static final List<Row> INPUT_DATA = + Arrays.asList( + Row.of( + 0, + Vectors.dense(2.1, 3.1), + 1.0, + Vectors.sparse(5, new int[] {3}, new double[] {1.0})), + Row.of( + 1, + Vectors.dense(2.1, 3.1), + 1.0, + Vectors.sparse( + 5, new int[] {4, 2, 3, 1}, new double[] {4.0, 2.0, 3.0, 1.0})), + Row.of(2, null, 1.0, null)); + + private static final SparseVector EXPECTED_DATA_1 = + Vectors.sparse(8, new int[] {0, 1, 2, 6}, new double[] {2.1, 3.1, 1.0, 1.0}); + private static final DenseVector EXPECTED_DATA_2 = + Vectors.dense(2.1, 3.1, 1.0, 0.0, 1.0, 2.0, 3.0, 4.0); + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(4); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + Schema schema = + Schema.newBuilder() + .column("f0", DataTypes.INT()) + .column("f1", DataTypes.of(DenseVector.class)) + .column("f2", DataTypes.DOUBLE()) + .column("f3", DataTypes.of(SparseVector.class)) + .build(); + DataStream<Row> dataStream = env.fromCollection(INPUT_DATA); + inputDataTable = + tEnv.fromDataStream(dataStream, schema).as("id", "vec", "num", "sparseVec"); + } + + private void verifyOutputResult(Table output, String outputCol) throws Exception { + DataStream<Row> dataStream = 
tEnv.toDataStream(output); + List<Row> results = IteratorUtils.toList(dataStream.executeAndCollect()); + assertEquals(3, results.size()); + for (Row result : results) { + if (result.getField(0) == (Object) 0) { + assertEquals(EXPECTED_DATA_1, result.getField(outputCol)); + } else if (result.getField(0) == (Object) 1) { + assertEquals(EXPECTED_DATA_2, result.getField(outputCol)); + } else { + assertNull(result.getField(outputCol)); + } + } + } + + @Test + public void testParam() { + VectorAssembler vectorAssembler = new VectorAssembler(); + assertEquals(HasHandleInvalid.ERROR_INVALID, vectorAssembler.getHandleInvalid()); + assertEquals("output", vectorAssembler.getOutputCol()); + vectorAssembler + .setInputCols("vec", "num", "sparseVec") + .setOutputCol("assembledVec") + .setHandleInvalid(HasHandleInvalid.SKIP_INVALID); + assertArrayEquals(new String[] {"vec", "num", "sparseVec"}, vectorAssembler.getInputCols()); + assertEquals(HasHandleInvalid.SKIP_INVALID, vectorAssembler.getHandleInvalid()); + assertEquals("assembledVec", vectorAssembler.getOutputCol()); + } + + @Test + public void testTransform() throws Exception { + VectorAssembler vectorAssembler = + new VectorAssembler() + .setInputCols("vec", "num", "sparseVec") + .setOutputCol("assembledVec") + .setHandleInvalid(HasHandleInvalid.SKIP_INVALID); + Table output = vectorAssembler.transform(inputDataTable)[0]; Review comment: Could we also test the schema of the output table, similar to what we did in `KMeansTest::testFitAndPredict()`? ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java ########## @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.vectorassembler; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Transformer; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.common.param.HasHandleInvalid; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * Vector assembler is a transformer that combines a given list of columns into a single vector + * column. It will combine raw features and features generated by different feature transformers + * into a single feature vector. 
The input features of this transformer must be a vector feature or + * a numerical feature. + */ +public class VectorAssembler + implements Transformer<VectorAssembler>, VectorAssemblerParams<VectorAssembler> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private static final double RATIO = 1.5; + + public VectorAssembler() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Table[] transform(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll( + inputTypeInfo.getFieldTypes(), TypeInformation.of(Vector.class)), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol())); + DataStream<Row> output = + tEnv.toDataStream(inputs[0]) + .map(new AssemblerFunc(getInputCols(), getHandleInvalid()), outputTypeInfo); + Table outputTable = tEnv.fromDataStream(output); + return new Table[] {outputTable}; + } + + private static class AssemblerFunc implements MapFunction<Row, Row> { + private final String[] inputCols; + private final String handleInvalid; + + public AssemblerFunc(String[] inputCols, String handleInvalid) { + this.inputCols = inputCols; + this.handleInvalid = handleInvalid; + } + + @Override + public Row map(Row value) { + Object[] objects = new Object[inputCols.length]; + for (int i = 0; i < objects.length; ++i) { + objects[i] = value.getField(inputCols[i]); + } + return Row.join(value, Row.of(assemble(objects, handleInvalid))); + } + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static VectorAssembler load(StreamTableEnvironment env, String path) throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + @Override + 
public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + private static Vector assemble(Object[] objects, String handleInvalid) { + int offset = 0; + Map<Integer, Double> map = new LinkedHashMap<>(objects.length); + for (Object object : objects) { + try { + Preconditions.checkNotNull(object, "VectorAssembler Error: input data is null."); + if (object instanceof Number) { + map.put(offset++, ((Number) object).doubleValue()); + } else if (object instanceof Vector) { + offset = appendVector((Vector) object, map, offset); + } else { + throw new UnsupportedOperationException( + "Vector assembler : input type has not been supported yet."); + } + } catch (Exception e) { + switch (handleInvalid) { + case HasHandleInvalid.ERROR_INVALID: + throw new RuntimeException("Vector assembler failed.", e); Review comment: Would it be simpler to just do `throw e` here? ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java ########## @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.feature.vectorassembler; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Transformer; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.common.param.HasHandleInvalid; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * Vector assembler is a transformer that combines a given list of columns into a single vector + * column. It will combine raw features and features generated by different feature transformers + * into a single feature vector. The input features of this transformer must be a vector feature or + * a numerical feature. + */ +public class VectorAssembler + implements Transformer<VectorAssembler>, VectorAssemblerParams<VectorAssembler> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private static final double RATIO = 1.5; + + public VectorAssembler() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Table[] transform(Table... 
inputs) { + Preconditions.checkArgument(inputs.length == 1); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll( + inputTypeInfo.getFieldTypes(), TypeInformation.of(Vector.class)), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol())); + DataStream<Row> output = + tEnv.toDataStream(inputs[0]) + .map(new AssemblerFunc(getInputCols(), getHandleInvalid()), outputTypeInfo); + Table outputTable = tEnv.fromDataStream(output); + return new Table[] {outputTable}; + } + + private static class AssemblerFunc implements MapFunction<Row, Row> { + private final String[] inputCols; + private final String handleInvalid; + + public AssemblerFunc(String[] inputCols, String handleInvalid) { + this.inputCols = inputCols; + this.handleInvalid = handleInvalid; + } + + @Override + public Row map(Row value) { + Object[] objects = new Object[inputCols.length]; + for (int i = 0; i < objects.length; ++i) { + objects[i] = value.getField(inputCols[i]); + } + return Row.join(value, Row.of(assemble(objects, handleInvalid))); + } + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static VectorAssembler load(StreamTableEnvironment env, String path) throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + private static Vector assemble(Object[] objects, String handleInvalid) { + int offset = 0; + Map<Integer, Double> map = new LinkedHashMap<>(objects.length); + for (Object object : objects) { + try { + Preconditions.checkNotNull(object, "VectorAssembler Error: input data is null."); Review comment: `VectorAssembler Error` seems redundant given that the stacktrace should also contain 
`VectorAssembler`. How about changing the error to `input column value should not be null`? ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java ########## @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.feature.vectorassembler; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Transformer; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.common.param.HasHandleInvalid; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * Vector assembler is a transformer that combines a given list of columns into a single vector + * column. It will combine raw features and features generated by different feature transformers + * into a single feature vector. The input features of this transformer must be a vector feature or + * a numerical feature. + */ +public class VectorAssembler + implements Transformer<VectorAssembler>, VectorAssemblerParams<VectorAssembler> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private static final double RATIO = 1.5; + + public VectorAssembler() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Table[] transform(Table... 
inputs) { + Preconditions.checkArgument(inputs.length == 1); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll( + inputTypeInfo.getFieldTypes(), TypeInformation.of(Vector.class)), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol())); + DataStream<Row> output = + tEnv.toDataStream(inputs[0]) + .map(new AssemblerFunc(getInputCols(), getHandleInvalid()), outputTypeInfo); + Table outputTable = tEnv.fromDataStream(output); + return new Table[] {outputTable}; + } + + private static class AssemblerFunc implements MapFunction<Row, Row> { + private final String[] inputCols; + private final String handleInvalid; + + public AssemblerFunc(String[] inputCols, String handleInvalid) { + this.inputCols = inputCols; + this.handleInvalid = handleInvalid; + } + + @Override + public Row map(Row value) { + Object[] objects = new Object[inputCols.length]; + for (int i = 0; i < objects.length; ++i) { + objects[i] = value.getField(inputCols[i]); + } + return Row.join(value, Row.of(assemble(objects, handleInvalid))); + } + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static VectorAssembler load(StreamTableEnvironment env, String path) throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + private static Vector assemble(Object[] objects, String handleInvalid) { + int offset = 0; + Map<Integer, Double> map = new LinkedHashMap<>(objects.length); + for (Object object : objects) { + try { + Preconditions.checkNotNull(object, "VectorAssembler Error: input data is null."); + if (object instanceof Number) { + map.put(offset++, ((Number) object).doubleValue()); + } else if (object 
instanceof Vector) { + offset = appendVector((Vector) object, map, offset); + } else { + throw new UnsupportedOperationException( + "Vector assembler : input type has not been supported yet."); + } + } catch (Exception e) { + switch (handleInvalid) { + case HasHandleInvalid.ERROR_INVALID: + throw new RuntimeException("Vector assembler failed.", e); + case HasHandleInvalid.SKIP_INVALID: + return null; + default: + } + } + } + + if (map.size() * RATIO > offset) { + DenseVector assembledVector = new DenseVector(offset); + for (int key : map.keySet()) { + assembledVector.values[key] = map.get(key); + } + return assembledVector; + } else { + return convertMapToSparseVector(offset, map); + } + } + + private static int appendVector(Vector vec, Map<Integer, Double> map, int offset) { + if (vec instanceof SparseVector) { + SparseVector sparseVector = (SparseVector) vec; + int[] indices = sparseVector.indices; + double[] values = sparseVector.values; + for (int j = 0; j < indices.length; ++j) { Review comment: nits: Could we use `for (int i = 0; i < indices.length; i++)` to be consistent with other for loop code style? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org