This is an automated email from the ASF dual-hosted git repository. gaoyunhaii pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push: new 79d0c24 [FLINK-24955] Add Estimator and Transformer for One Hot Encoder 79d0c24 is described below commit 79d0c247e0cc900c9a16e045a99000423003780b Author: Yunfeng Zhou <yuri.zhouyunf...@outlook.com> AuthorDate: Thu Dec 2 15:32:17 2021 +0800 [FLINK-24955] Add Estimator and Transformer for One Hot Encoder This closes #37. --- .../org/apache/flink/ml/linalg/SparseVector.java | 168 ++++++++++++ .../java/org/apache/flink/ml/linalg/Vectors.java | 5 + .../ml/linalg/typeinfo/DenseVectorSerializer.java | 8 +- .../ml/linalg/typeinfo/DenseVectorTypeInfo.java | 2 +- .../ml/linalg/typeinfo/SparseVectorSerializer.java | 151 +++++++++++ ...ctorTypeInfo.java => SparseVectorTypeInfo.java} | 66 +++-- .../typeinfo/SparseVectorTypeInfoFactory.java | 40 +++ .../org/apache/flink/ml/param/ParamValidators.java | 5 + .../java/org/apache/flink/ml/api/StageTest.java | 5 + .../apache/flink/ml/linalg/SparseVectorTest.java | 132 +++++++++ .../flink/ml/common/param/HasHandleInvalid.java | 56 ++++ .../apache/flink/ml/common/param/HasInputCols.java | 23 +- .../flink/ml/common/param/HasOutputCols.java | 23 +- .../ml/feature/onehotencoder/OneHotEncoder.java | 148 +++++++++++ .../feature/onehotencoder/OneHotEncoderModel.java | 190 +++++++++++++ .../onehotencoder/OneHotEncoderModelData.java | 109 ++++++++ .../feature/onehotencoder/OneHotEncoderParams.java | 28 +- .../apache/flink/ml/feature/OneHotEncoderTest.java | 294 +++++++++++++++++++++ 18 files changed, 1393 insertions(+), 60 deletions(-) diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/SparseVector.java b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/SparseVector.java new file mode 100644 index 0000000..4e683a4 --- /dev/null +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/SparseVector.java @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.linalg; + +import org.apache.flink.api.common.typeinfo.TypeInfo; +import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfoFactory; +import org.apache.flink.util.Preconditions; + +import java.util.Arrays; +import java.util.Objects; + +/** A sparse vector of double values. 
 */ +@TypeInfo(SparseVectorTypeInfoFactory.class) +public class SparseVector implements Vector { + public final int n; + public final int[] indices; + public final double[] values; + + public SparseVector(int n, int[] indices, double[] values) { + this.n = n; + this.indices = indices; + this.values = values; + if (!isIndicesSorted()) { + sortIndices(); + } + validateSortedData(); + } + + @Override + public int size() { + return n; + } + + @Override + public double get(int i) { + int pos = Arrays.binarySearch(indices, i); + if (pos >= 0) { + return values[pos]; + } + return 0.; + } + + @Override + public double[] toArray() { + double[] result = new double[n]; + for (int i = 0; i < indices.length; i++) { + result[indices[i]] = values[i]; + } + return result; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + SparseVector that = (SparseVector) o; + return n == that.n + && Arrays.equals(indices, that.indices) + && Arrays.equals(values, that.values); + } + + @Override + public int hashCode() { + int result = Objects.hash(n); + result = 31 * result + Arrays.hashCode(indices); + result = 31 * result + Arrays.hashCode(values); + return result; + } + + /** + * Checks whether input data is valid. + * + * <p>This function does the following checks: + * + * <ul> + * <li>The indices array and values array are of the same size. + * <li>vector indices are in valid range. + * <li>vector indices are unique. + * </ul> + * + * <p>This function works as expected only when indices are sorted. 
+ */ + private void validateSortedData() { + Preconditions.checkArgument( + indices.length == values.length, + "Indices size and values size should be the same."); + if (this.indices.length > 0) { + Preconditions.checkArgument( + this.indices[0] >= 0 && this.indices[this.indices.length - 1] < this.n, + "Index out of bound."); + } + for (int i = 1; i < this.indices.length; i++) { + Preconditions.checkArgument( + this.indices[i] > this.indices[i - 1], "Indices duplicated."); + } + } + + private boolean isIndicesSorted() { + for (int i = 1; i < this.indices.length; i++) { + if (this.indices[i] < this.indices[i - 1]) { + return false; + } + } + return true; + } + + /** Sorts the indices and values. */ + private void sortIndices() { + sortImpl(this.indices, this.values, 0, this.indices.length - 1); + } + + /** Sorts the indices and values using quick sort. */ + private static void sortImpl(int[] indices, double[] values, int low, int high) { + int pivotPos = (low + high) / 2; + int pivot = indices[pivotPos]; + swapIndexAndValue(indices, values, pivotPos, high); + + int pos = low - 1; + for (int i = low; i <= high; i++) { + if (indices[i] <= pivot) { + pos++; + swapIndexAndValue(indices, values, pos, i); + } + } + if (high > pos + 1) { + sortImpl(indices, values, pos + 1, high); + } + if (pos - 1 > low) { + sortImpl(indices, values, low, pos - 1); + } + } + + private static void swapIndexAndValue(int[] indices, double[] values, int index1, int index2) { + int tempIndex = indices[index1]; + indices[index1] = indices[index2]; + indices[index2] = tempIndex; + double tempValue = values[index1]; + values[index1] = values[index2]; + values[index2] = tempValue; + } + + @Override + public String toString() { + String sbr = + "(" + n + ", " + Arrays.toString(indices) + ", " + Arrays.toString(values) + ")"; + return sbr; + } +} diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java 
index a058755..424b27f 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java @@ -25,4 +25,9 @@ public class Vectors { public static DenseVector dense(double... values) { return new DenseVector(values); } + + /** Creates a sparse vector from its values. */ + public static SparseVector sparse(int size, int[] indices, double[] values) { + return new SparseVector(size, indices, values); + } } diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializer.java b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializer.java index 3cbde53..153a20f 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializer.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializer.java @@ -29,14 +29,14 @@ import org.apache.flink.ml.linalg.DenseVector; import java.io.IOException; import java.util.Arrays; -/** Specialized serializer for {@code DenseVector}. */ +/** Specialized serializer for {@link DenseVector}. 
*/ public final class DenseVectorSerializer extends TypeSerializerSingleton<DenseVector> { private static final long serialVersionUID = 1L; private static final double[] EMPTY = new double[0]; - private static final DenseVectorSerializer INSTANCE = new DenseVectorSerializer(); + public static final DenseVectorSerializer INSTANCE = new DenseVectorSerializer(); @Override public boolean isImmutableType() { @@ -84,9 +84,7 @@ public final class DenseVectorSerializer extends TypeSerializerSingleton<DenseVe public DenseVector deserialize(DataInputView source) throws IOException { int len = source.readInt(); double[] values = new double[len]; - for (int i = 0; i < len; i++) { - values[i] = source.readDouble(); - } + readDoubleArray(values, source, len); return new DenseVector(values); } diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorTypeInfo.java b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorTypeInfo.java index 0239e17..765cacb 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorTypeInfo.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorTypeInfo.java @@ -64,7 +64,7 @@ public class DenseVectorTypeInfo extends TypeInformation<DenseVector> { @Override @SuppressWarnings("unchecked") public TypeSerializer<DenseVector> createSerializer(ExecutionConfig executionConfig) { - return new DenseVectorSerializer(); + return DenseVectorSerializer.INSTANCE; } // -------------------------------------------------------------------------------------------- diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorSerializer.java b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorSerializer.java new file mode 100644 index 0000000..2c922a9 --- /dev/null +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorSerializer.java @@ -0,0 +1,151 @@ +/* + * Licensed to the 
Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.flink.ml.linalg.typeinfo; + +import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.ml.linalg.SparseVector; + +import java.io.IOException; +import java.util.Arrays; + +/** Specialized serializer for {@link SparseVector}. 
*/ +public final class SparseVectorSerializer extends TypeSerializerSingleton<SparseVector> { + + private static final long serialVersionUID = 1L; + + private static final double[] EMPTY_DOUBLE_ARRAY = new double[0]; + + private static final int[] EMPTY_INT_ARRAY = new int[0]; + + public static final SparseVectorSerializer INSTANCE = new SparseVectorSerializer(); + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public SparseVector createInstance() { + return new SparseVector(0, EMPTY_INT_ARRAY, EMPTY_DOUBLE_ARRAY); + } + + @Override + public SparseVector copy(SparseVector from) { + return new SparseVector( + from.n, + Arrays.copyOf(from.indices, from.indices.length), + Arrays.copyOf(from.values, from.values.length)); + } + + @Override + public SparseVector copy(SparseVector from, SparseVector reuse) { + if (from.values.length == reuse.values.length && from.n == reuse.n) { + System.arraycopy(from.values, 0, reuse.values, 0, from.values.length); + System.arraycopy(from.indices, 0, reuse.indices, 0, from.indices.length); + return reuse; + } + return copy(from); + } + + @Override + public int getLength() { + return -1; + } + + @Override + public void serialize(SparseVector vector, DataOutputView target) throws IOException { + if (vector == null) { + throw new IllegalArgumentException("The vector must not be null."); + } + + target.writeInt(vector.n); + final int len = vector.values.length; + target.writeInt(len); + for (int i = 0; i < len; i++) { + target.writeInt(vector.indices[i]); + target.writeDouble(vector.values[i]); + } + } + + // Reads `len` int values from `source` into `indices` and `len` double values from `source` + // into `values`. 
+ private void readSparseVectorArrays( + int[] indices, double[] values, DataInputView source, int len) throws IOException { + for (int i = 0; i < len; i++) { + indices[i] = source.readInt(); + values[i] = source.readDouble(); + } + } + + @Override + public SparseVector deserialize(DataInputView source) throws IOException { + int n = source.readInt(); + int len = source.readInt(); + int[] indices = new int[len]; + double[] values = new double[len]; + readSparseVectorArrays(indices, values, source, len); + return new SparseVector(n, indices, values); + } + + @Override + public SparseVector deserialize(SparseVector reuse, DataInputView source) throws IOException { + int n = source.readInt(); + int len = source.readInt(); + if (reuse.n == n && reuse.values.length == len) { + readSparseVectorArrays(reuse.indices, reuse.values, source, len); + return reuse; + } + + int[] indices = new int[len]; + double[] values = new double[len]; + readSparseVectorArrays(indices, values, source, len); + return new SparseVector(n, indices, values); + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + int n = source.readInt(); + int len = source.readInt(); + + target.writeInt(n); + target.writeInt(len); + + target.write(source, len * 12); + } + + @Override + public TypeSerializerSnapshot<SparseVector> snapshotConfiguration() { + return new SparseVectorSerializerSnapshot(); + } + + /** Serializer configuration snapshot for compatibility and format evolution. 
*/ + @SuppressWarnings("WeakerAccess") + public static final class SparseVectorSerializerSnapshot + extends SimpleTypeSerializerSnapshot<SparseVector> { + + public SparseVectorSerializerSnapshot() { + super(() -> INSTANCE); + } + } +} diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorTypeInfo.java b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorTypeInfo.java similarity index 50% copy from flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorTypeInfo.java copy to flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorTypeInfo.java index 0239e17..06686f0 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorTypeInfo.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorTypeInfo.java @@ -7,13 +7,14 @@ * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
*/ package org.apache.flink.ml.linalg.typeinfo; @@ -21,39 +22,35 @@ package org.apache.flink.ml.linalg.typeinfo; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.SparseVector; -/** A {@link TypeInformation} for the {@link DenseVector} type. */ -public class DenseVectorTypeInfo extends TypeInformation<DenseVector> { - private static final long serialVersionUID = 1L; - - public static final DenseVectorTypeInfo INSTANCE = new DenseVectorTypeInfo(); - - public DenseVectorTypeInfo() {} +/** A {@link TypeInformation} for the {@link SparseVector} type. */ +public class SparseVectorTypeInfo extends TypeInformation<SparseVector> { + public static final SparseVectorTypeInfo INSTANCE = new SparseVectorTypeInfo(); @Override - public int getArity() { - return 1; + public boolean isBasicType() { + return false; } @Override - public int getTotalFields() { - return 1; + public boolean isTupleType() { + return false; } @Override - public Class<DenseVector> getTypeClass() { - return DenseVector.class; + public int getArity() { + return 3; } @Override - public boolean isBasicType() { - return false; + public int getTotalFields() { + return 3; } @Override - public boolean isTupleType() { - return false; + public Class<SparseVector> getTypeClass() { + return SparseVector.class; } @Override @@ -62,30 +59,27 @@ public class DenseVectorTypeInfo extends TypeInformation<DenseVector> { } @Override - @SuppressWarnings("unchecked") - public TypeSerializer<DenseVector> createSerializer(ExecutionConfig executionConfig) { - return new DenseVectorSerializer(); + public TypeSerializer<SparseVector> createSerializer(ExecutionConfig executionConfig) { + return SparseVectorSerializer.INSTANCE; } - // -------------------------------------------------------------------------------------------- - 
@Override - public int hashCode() { - return getClass().hashCode(); + public String toString() { + return "SparseVectorType"; } @Override public boolean equals(Object obj) { - return obj instanceof DenseVectorTypeInfo; + return obj instanceof SparseVectorTypeInfo; } @Override - public boolean canEqual(Object obj) { - return obj instanceof DenseVectorTypeInfo; + public int hashCode() { + return getClass().hashCode(); } @Override - public String toString() { - return "DenseVectorType"; + public boolean canEqual(Object obj) { + return obj instanceof SparseVectorTypeInfo; } } diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorTypeInfoFactory.java b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorTypeInfoFactory.java new file mode 100644 index 0000000..01c1036 --- /dev/null +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorTypeInfoFactory.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.flink.ml.linalg.typeinfo; + +import org.apache.flink.api.common.typeinfo.TypeInfoFactory; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.TypeExtractor; +import org.apache.flink.ml.linalg.SparseVector; + +import java.lang.reflect.Type; +import java.util.Map; + +/** + * Used by {@link TypeExtractor} to create a {@link TypeInformation} for implementations of {@link + * SparseVector}. + */ +public class SparseVectorTypeInfoFactory extends TypeInfoFactory<SparseVector> { + @Override + public TypeInformation<SparseVector> createTypeInfo( + Type type, Map<String, TypeInformation<?>> map) { + return new SparseVectorTypeInfo(); + } +} diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/param/ParamValidators.java b/flink-ml-core/src/main/java/org/apache/flink/ml/param/ParamValidators.java index 925ccb2..e7d1436 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/param/ParamValidators.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/param/ParamValidators.java @@ -95,4 +95,9 @@ public class ParamValidators { } }; } + + // Check if the parameter value array is not an empty array. 
+ public static <T> ParamValidator<T[]> nonEmptyArray() { + return value -> value != null && value.length > 0; + } } diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/api/StageTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/api/StageTest.java index 6ac630d..df0db64 100644 --- a/flink-ml-core/src/test/java/org/apache/flink/ml/api/StageTest.java +++ b/flink-ml-core/src/test/java/org/apache/flink/ml/api/StageTest.java @@ -388,5 +388,10 @@ public class StageTest { ParamValidator<Integer> notNull = ParamValidators.notNull(); Assert.assertTrue(notNull.validate(5)); Assert.assertFalse(notNull.validate(null)); + + ParamValidator<Object[]> nonEmptyArray = ParamValidators.nonEmptyArray(); + Assert.assertTrue(nonEmptyArray.validate(new String[] {"1"})); + Assert.assertFalse(nonEmptyArray.validate(null)); + Assert.assertFalse(nonEmptyArray.validate(new String[0])); } } diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/SparseVectorTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/SparseVectorTest.java new file mode 100644 index 0000000..0e7c349 --- /dev/null +++ b/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/SparseVectorTest.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.linalg; + +import org.apache.flink.core.memory.DataInputViewStreamWrapper; +import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.apache.flink.ml.linalg.typeinfo.SparseVectorSerializer; + +import org.apache.commons.io.output.ByteArrayOutputStream; +import org.junit.Assert; +import org.junit.Test; + +import java.io.ByteArrayInputStream; +import java.io.IOException; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +/** Tests the behavior of Vectors. */ +public class SparseVectorTest { + @Test + public void testConstructor() { + int n = 4; + int[] indices = new int[] {0, 2, 3}; + double[] values = new double[] {0.1, 0.3, 0.4}; + + SparseVector vector = Vectors.sparse(n, indices, values); + assertEquals(n, vector.n); + assertArrayEquals(indices, vector.indices); + assertArrayEquals(values, vector.values, 1e-5); + assertEquals("(4, [0, 2, 3], [0.1, 0.3, 0.4])", vector.toString()); + } + + @Test + public void testDuplicateIndex() { + int n = 4; + int[] indices = new int[] {0, 2, 2}; + double[] values = new double[] {0.1, 0.3, 0.4}; + + try { + Vectors.sparse(n, indices, values); + Assert.fail("Expected IllegalArgumentException."); + } catch (Exception e) { + assertEquals(IllegalArgumentException.class, e.getClass()); + assertEquals("Indices duplicated.", e.getMessage()); + } + } + + @Test + public void testAllZeroVector() { + int n = 4; + SparseVector vector = Vectors.sparse(n, new int[0], new double[0]); + assertArrayEquals(vector.toArray(), new double[n], 1e-5); + } + + @Test + public void testUnsortedIndex() { + SparseVector vector; + + vector = Vectors.sparse(4, new int[] {2}, new double[] {0.3}); + assertEquals(4, vector.n); + assertArrayEquals(new int[] {2}, vector.indices); + assertArrayEquals(new double[] {0.3}, vector.values, 1e-5); + + vector = 
Vectors.sparse(4, new int[] {1, 2}, new double[] {0.2, 0.3}); + assertEquals(4, vector.n); + assertArrayEquals(new int[] {1, 2}, vector.indices); + assertArrayEquals(new double[] {0.2, 0.3}, vector.values, 1e-5); + + vector = Vectors.sparse(4, new int[] {2, 1}, new double[] {0.3, 0.2}); + assertEquals(4, vector.n); + assertArrayEquals(new int[] {1, 2}, vector.indices); + assertArrayEquals(new double[] {0.2, 0.3}, vector.values, 1e-5); + + vector = Vectors.sparse(4, new int[] {3, 2, 0}, new double[] {0.4, 0.3, 0.1}); + assertEquals(4, vector.n); + assertArrayEquals(new int[] {0, 2, 3}, vector.indices); + assertArrayEquals(new double[] {0.1, 0.3, 0.4}, vector.values, 1e-5); + + vector = Vectors.sparse(4, new int[] {2, 0, 3}, new double[] {0.3, 0.1, 0.4}); + assertEquals(4, vector.n); + assertArrayEquals(new int[] {0, 2, 3}, vector.indices); + assertArrayEquals(new double[] {0.1, 0.3, 0.4}, vector.values, 1e-5); + + vector = + Vectors.sparse( + 7, + new int[] {6, 5, 4, 3, 2, 1, 0}, + new double[] {0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1}); + assertEquals(7, vector.n); + assertArrayEquals(new int[] {0, 1, 2, 3, 4, 5, 6}, vector.indices); + assertArrayEquals(new double[] {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7}, vector.values, 1e-5); + } + + @Test + public void testSerializer() throws IOException { + int n = 4; + int[] indices = new int[] {0, 2, 3}; + double[] values = new double[] {0.1, 0.3, 0.4}; + SparseVector vector = Vectors.sparse(n, indices, values); + SparseVectorSerializer serializer = SparseVectorSerializer.INSTANCE; + + ByteArrayOutputStream bOutput = new ByteArrayOutputStream(1024); + DataOutputViewStreamWrapper output = new DataOutputViewStreamWrapper(bOutput); + serializer.serialize(vector, output); + + byte[] b = bOutput.toByteArray(); + ByteArrayInputStream bInput = new ByteArrayInputStream(b); + DataInputViewStreamWrapper input = new DataInputViewStreamWrapper(bInput); + SparseVector vector2 = serializer.deserialize(input); + + assertEquals(vector.n, vector2.n); + 
assertArrayEquals(vector.indices, vector2.indices); + assertArrayEquals(vector.values, vector2.values, 1e-5); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasHandleInvalid.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasHandleInvalid.java new file mode 100644 index 0000000..a7ea41a --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasHandleInvalid.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.param; + +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.StringParam; +import org.apache.flink.ml.param.WithParams; + +/** + * Interface for the shared handleInvalid param. + * + * <p>Supported options and the corresponding behavior to handle invalid entries is listed as + * follows. + * + * <ul> + * <li>error: raise an exception. 
+ * <li>skip: filter out rows with bad values + * </ul> + */ +public interface HasHandleInvalid<T> extends WithParams<T> { + String ERROR_INVALID = "error"; + String SKIP_INVALID = "skip"; + + Param<String> HANDLE_INVALID = + new StringParam( + "handleInvalid", + "Strategy to handle invalid entries.", + ERROR_INVALID, + ParamValidators.inArray(ERROR_INVALID, SKIP_INVALID)); + + default String getHandleInvalid() { + return get(HANDLE_INVALID); + } + + default T setHandleInvalid(String value) { + set(HANDLE_INVALID, value); + return (T) this; + } +} diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasInputCols.java similarity index 55% copy from flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java copy to flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasInputCols.java index a058755..c567de7 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasInputCols.java @@ -16,13 +16,24 @@ * limitations under the License. */ -package org.apache.flink.ml.linalg; +package org.apache.flink.ml.common.param; -/** Utility methods for instantiating Vector. */ -public class Vectors { +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.StringArrayParam; +import org.apache.flink.ml.param.WithParams; - /** Creates a dense vector from its values. */ - public static DenseVector dense(double... values) { - return new DenseVector(values); +/** Interface for the shared inputCols param. */ +public interface HasInputCols<T> extends WithParams<T> { + Param<String[]> INPUT_COLS = + new StringArrayParam( + "inputCols", "Input column names.", null, ParamValidators.nonEmptyArray()); + + default String[] getInputCols() { + return get(INPUT_COLS); + } + + default T setInputCols(String... 
value) { + return set(INPUT_COLS, value); } } diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasOutputCols.java similarity index 54% copy from flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java copy to flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasOutputCols.java index a058755..947501f 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasOutputCols.java @@ -16,13 +16,24 @@ * limitations under the License. */ -package org.apache.flink.ml.linalg; +package org.apache.flink.ml.common.param; -/** Utility methods for instantiating Vector. */ -public class Vectors { +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.StringArrayParam; +import org.apache.flink.ml.param.WithParams; - /** Creates a dense vector from its values. */ - public static DenseVector dense(double... values) { - return new DenseVector(values); +/** Interface for the shared outputCols param. */ +public interface HasOutputCols<T> extends WithParams<T> { + Param<String[]> OUTPUT_COLS = + new StringArrayParam( + "outputCols", "Output column names.", null, ParamValidators.nonEmptyArray()); + + default String[] getOutputCols() { + return get(OUTPUT_COLS); + } + + default T setOutputCols(String... value) { + return set(OUTPUT_COLS, value); } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoder.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoder.java new file mode 100644 index 0000000..374d457 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoder.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.onehotencoder; + +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.datastream.MapPartitionFunctionWrapper; +import org.apache.flink.ml.common.param.HasHandleInvalid; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +/** + * An Estimator which implements the one-hot encoding algorithm. 
+ * + * <p>Data of selected input columns should be indexed numbers in order for OneHotEncoder to + * function correctly. + * + * <p>See https://en.wikipedia.org/wiki/One-hot. + */ +public class OneHotEncoder + implements Estimator<OneHotEncoder, OneHotEncoderModel>, + OneHotEncoderParams<OneHotEncoder> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + public OneHotEncoder() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public OneHotEncoderModel fit(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + Preconditions.checkArgument(getHandleInvalid().equals(HasHandleInvalid.ERROR_INVALID)); + + final String[] inputCols = getInputCols(); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<Tuple2<Integer, Integer>> modelData = + tEnv.toDataStream(inputs[0]) + .flatMap(new ExtractInputColsValueFunction(inputCols)) + .keyBy(columnIdAndValue -> columnIdAndValue.f0) + .transform( + "findMaxIndex", + Types.TUPLE(Types.INT, Types.INT), + new MapPartitionFunctionWrapper<>(new FindMaxIndexFunction())); + + OneHotEncoderModel model = + new OneHotEncoderModel().setModelData(tEnv.fromDataStream(modelData)); + ReadWriteUtils.updateExistingParams(model, paramMap); + return model; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static OneHotEncoder load(StreamExecutionEnvironment env, String path) + throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + /** + * Extract values of input columns of input data. 
+ * + * <p>Input: rows of input data containing designated input columns + * + * <p>Output: Pairs of column index and value stored in those columns + */ + private static class ExtractInputColsValueFunction + implements FlatMapFunction<Row, Tuple2<Integer, Integer>> { + private final String[] inputCols; + + private ExtractInputColsValueFunction(String[] inputCols) { + this.inputCols = inputCols; + } + + @Override + public void flatMap(Row row, Collector<Tuple2<Integer, Integer>> collector) { + for (int i = 0; i < inputCols.length; i++) { + Number number = (Number) row.getField(inputCols[i]); + Preconditions.checkArgument( + number.intValue() == number.doubleValue(), + String.format("Value %s cannot be parsed as indexed integer.", number)); + Preconditions.checkArgument( + number.intValue() >= 0, "Negative value not supported."); + collector.collect(new Tuple2<>(i, number.intValue())); + } + } + } + + /** Function to find the max index value for each column. */ + private static class FindMaxIndexFunction + implements MapPartitionFunction<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>> { + + @Override + public void mapPartition( + Iterable<Tuple2<Integer, Integer>> iterable, + Collector<Tuple2<Integer, Integer>> collector) { + Map<Integer, Integer> map = new HashMap<>(); + for (Tuple2<Integer, Integer> value : iterable) { + map.put( + value.f0, + Math.max(map.getOrDefault(value.f0, Integer.MIN_VALUE), value.f1)); + } + for (Map.Entry<Integer, Integer> entry : map.entrySet()) { + collector.collect(new Tuple2<>(entry.getKey(), entry.getValue())); + } + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoderModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoderModel.java new file mode 100644 index 0000000..447fe77 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoderModel.java @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache 
Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.onehotencoder; + +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.common.param.HasHandleInvalid; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.table.runtime.typeutils.ExternalTypeInfo; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import 
java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Vector; +import java.util.function.Function; + +/** + * A Model which encodes data into one-hot format using the model data computed by {@link + * OneHotEncoder}. + */ +public class OneHotEncoderModel + implements Model<OneHotEncoderModel>, OneHotEncoderParams<OneHotEncoderModel> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private Table modelDataTable; + + public OneHotEncoderModel() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Table[] transform(Table... inputs) { + final String[] inputCols = getInputCols(); + final String[] outputCols = getOutputCols(); + final boolean dropLast = getDropLast(); + final String broadcastModelKey = "OneHotModelStream"; + + Preconditions.checkArgument(getHandleInvalid().equals(HasHandleInvalid.ERROR_INVALID)); + Preconditions.checkArgument(inputs.length == 1); + Preconditions.checkArgument(inputCols.length == outputCols.length); + + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll( + inputTypeInfo.getFieldTypes(), + Collections.nCopies( + outputCols.length, + ExternalTypeInfo.of(Vector.class)) + .toArray(new TypeInformation[0])), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), outputCols)); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) modelDataTable).getTableEnvironment(); + DataStream<Row> input = tEnv.toDataStream(inputs[0]); + DataStream<Tuple2<Integer, Integer>> modelStream = + OneHotEncoderModelData.getModelDataStream(modelDataTable); + + Function<List<DataStream<?>>, DataStream<Row>> function = + dataStreams -> { + DataStream stream = dataStreams.get(0); + return stream.map( + new GenerateOutputsFunction(inputCols, dropLast, broadcastModelKey), + outputTypeInfo); + }; + + 
DataStream<Row> output = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(input), + Collections.singletonMap(broadcastModelKey, modelStream), + function); + + Table outputTable = tEnv.fromDataStream(output); + + return new Table[] {outputTable}; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + ReadWriteUtils.saveModelData( + OneHotEncoderModelData.getModelDataStream(modelDataTable), + path, + new OneHotEncoderModelData.ModelDataEncoder()); + } + + public static OneHotEncoderModel load(StreamExecutionEnvironment env, String path) + throws IOException { + StreamTableEnvironment tEnv = StreamTableEnvironment.create(env); + OneHotEncoderModel model = ReadWriteUtils.loadStageParam(path); + DataStream<Tuple2<Integer, Integer>> modelData = + ReadWriteUtils.loadModelData( + env, path, new OneHotEncoderModelData.ModelDataStreamFormat()); + return model.setModelData(tEnv.fromDataStream(modelData)); + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + @Override + public OneHotEncoderModel setModelData(Table... 
inputs) { + Preconditions.checkArgument(inputs.length == 1); + modelDataTable = inputs[0]; + return this; + } + + @Override + public Table[] getModelData() { + return new Table[] {modelDataTable}; + } + + private static class GenerateOutputsFunction extends RichMapFunction<Row, Row> { + private final String[] inputCols; + private final boolean dropLast; + private final String broadcastModelKey; + private List<Tuple2<Integer, Integer>> model = null; + + public GenerateOutputsFunction( + String[] inputCols, boolean dropLast, String broadcastModelKey) { + this.inputCols = inputCols; + this.dropLast = dropLast; + this.broadcastModelKey = broadcastModelKey; + } + + @Override + public Row map(Row row) { + if (model == null) { + model = getRuntimeContext().getBroadcastVariable(broadcastModelKey); + } + int[] categorySizes = new int[model.size()]; + int offset = dropLast ? 0 : 1; + for (Tuple2<Integer, Integer> tup : model) { + categorySizes[tup.f0] = tup.f1 + offset; + } + Row result = new Row(categorySizes.length); + for (int i = 0; i < categorySizes.length; i++) { + Number number = (Number) row.getField(inputCols[i]); + Preconditions.checkArgument( + number.intValue() == number.doubleValue(), + String.format("Value %s cannot be parsed as indexed integer.", number)); + int idx = number.intValue(); + if (idx == categorySizes[i]) { + result.setField(i, Vectors.sparse(categorySizes[i], new int[0], new double[0])); + } else { + result.setField( + i, + Vectors.sparse(categorySizes[i], new int[] {idx}, new double[] {1.0})); + } + } + + return Row.join(row, result); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoderModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoderModelData.java new file mode 100644 index 0000000..f267784 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoderModelData.java @@ -0,0 +1,109 @@ +/* + * Licensed to 
the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.onehotencoder; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.serialization.Encoder; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.connector.file.src.reader.SimpleStreamFormat; +import org.apache.flink.core.fs.FSDataInputStream; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; + +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; + +import java.io.IOException; +import java.io.OutputStream; + +/** + * Model data of {@link OneHotEncoderModel}. + * + * <p>This class also provides methods to convert model data from Table to Datastream, and classes + * to save/load model data. 
+ */ +public class OneHotEncoderModelData { + /** + * Converts the table model to a data stream. + * + * @param modelData The table model data. + * @return The data stream model data. + */ + public static DataStream<Tuple2<Integer, Integer>> getModelDataStream(Table modelData) { + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) modelData).getTableEnvironment(); + return tEnv.toDataStream(modelData) + .map( + new MapFunction<Row, Tuple2<Integer, Integer>>() { + @Override + public Tuple2<Integer, Integer> map(Row row) { + return new Tuple2<>( + (int) row.getField("f0"), (int) row.getField("f1")); + } + }); + } + + /** Data encoder for the OneHotEncoder model data. */ + public static class ModelDataEncoder implements Encoder<Tuple2<Integer, Integer>> { + @Override + public void encode(Tuple2<Integer, Integer> modelData, OutputStream outputStream) { + Output output = new Output(outputStream); + output.writeInt(modelData.f0); + output.writeInt(modelData.f1); + output.flush(); + } + } + + /** Data decoder for the OneHotEncoder model data. 
*/ + public static class ModelDataStreamFormat extends SimpleStreamFormat<Tuple2<Integer, Integer>> { + @Override + public Reader<Tuple2<Integer, Integer>> createReader( + Configuration config, FSDataInputStream stream) { + return new Reader<Tuple2<Integer, Integer>>() { + private final Input input = new Input(stream); + + @Override + public Tuple2<Integer, Integer> read() { + if (input.eof()) { + return null; + } + int f0 = input.readInt(); + int f1 = input.readInt(); + return new Tuple2<>(f0, f1); + } + + @Override + public void close() throws IOException { + stream.close(); + } + }; + } + + @Override + public TypeInformation<Tuple2<Integer, Integer>> getProducedType() { + return Types.TUPLE(Types.INT, Types.INT); + } + } +} diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoderParams.java similarity index 51% copy from flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java copy to flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoderParams.java index a058755..9b57159 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoderParams.java @@ -16,13 +16,29 @@ * limitations under the License. */ -package org.apache.flink.ml.linalg; +package org.apache.flink.ml.feature.onehotencoder; -/** Utility methods for instantiating Vector. */ -public class Vectors { +import org.apache.flink.ml.common.param.HasHandleInvalid; +import org.apache.flink.ml.common.param.HasInputCols; +import org.apache.flink.ml.common.param.HasOutputCols; +import org.apache.flink.ml.param.BooleanParam; +import org.apache.flink.ml.param.Param; - /** Creates a dense vector from its values. */ - public static DenseVector dense(double... values) { - return new DenseVector(values); +/** + * Params of OneHotEncoderModel. 
+ * + * @param <T> The class type of this instance. + */ +public interface OneHotEncoderParams<T> + extends HasInputCols<T>, HasOutputCols<T>, HasHandleInvalid<T> { + Param<Boolean> DROP_LAST = + new BooleanParam("dropLast", "Whether to drop the last category.", true); + + default boolean getDropLast() { + return get(DROP_LAST); + } + + default T setDropLast(boolean value) { + return set(DROP_LAST, value); } } diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OneHotEncoderTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OneHotEncoderTest.java new file mode 100644 index 0000000..51f9735 --- /dev/null +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OneHotEncoderTest.java @@ -0,0 +1,294 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.feature; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.common.param.HasHandleInvalid; +import org.apache.flink.ml.feature.onehotencoder.OneHotEncoder; +import org.apache.flink.ml.feature.onehotencoder.OneHotEncoderModel; +import org.apache.flink.ml.feature.onehotencoder.OneHotEncoderModelData; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.ml.util.StageTestUtils; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; +import org.apache.flink.util.CloseableIterator; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +/** Tests OneHotEncoder and OneHotEncoderModel. 
*/ +public class OneHotEncoderTest { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + private Table trainTable; + private Table predictTable; + private Map<Double, Vector>[] expectedOutput; + private OneHotEncoder estimator; + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(4); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + + List<Row> trainData = Arrays.asList(Row.of(0.0), Row.of(1.0), Row.of(2.0), Row.of(0.0)); + + trainTable = tEnv.fromDataStream(env.fromCollection(trainData)).as("input"); + + List<Row> predictData = Arrays.asList(Row.of(0.0), Row.of(1.0), Row.of(2.0)); + + predictTable = tEnv.fromDataStream(env.fromCollection(predictData)).as("input"); + + expectedOutput = + new HashMap[] { + new HashMap<Double, Vector>() { + { + put(0.0, Vectors.sparse(2, new int[] {0}, new double[] {1.0})); + put(1.0, Vectors.sparse(2, new int[] {1}, new double[] {1.0})); + put(2.0, Vectors.sparse(2, new int[0], new double[0])); + } + } + }; + + estimator = new OneHotEncoder().setInputCols("input").setOutputCols("output"); + } + + /** + * Executes a given table and collect its results. Results are returned as a map array. Each + * element in the array is a map corresponding to a input column whose key is the original value + * in the input column, value is the one-hot encoding result of that value. 
+ * + * @param table A table to be executed and to have its result collected + * @param inputCols Name of the input columns + * @param outputCols Name of the output columns containing one-hot encoding result + * @return An array of map containing the collected results for each input column + */ + private static Map<Double, Vector>[] executeAndCollect( + Table table, String[] inputCols, String[] outputCols) { + Map<Double, Vector>[] maps = new HashMap[inputCols.length]; + for (int i = 0; i < inputCols.length; i++) { + maps[i] = new HashMap<>(); + } + for (CloseableIterator<Row> it = table.execute().collect(); it.hasNext(); ) { + Row row = it.next(); + for (int i = 0; i < inputCols.length; i++) { + maps[i].put( + ((Number) row.getField(inputCols[i])).doubleValue(), + (Vector) row.getField(outputCols[i])); + } + } + return maps; + } + + @Test + public void testParam() { + OneHotEncoder estimator = new OneHotEncoder(); + + assertTrue(estimator.getDropLast()); + + estimator.setInputCols("test_input").setOutputCols("test_output").setDropLast(false); + + assertArrayEquals(new String[] {"test_input"}, estimator.getInputCols()); + assertArrayEquals(new String[] {"test_output"}, estimator.getOutputCols()); + assertFalse(estimator.getDropLast()); + + OneHotEncoderModel model = new OneHotEncoderModel(); + + assertTrue(model.getDropLast()); + + model.setInputCols("test_input").setOutputCols("test_output").setDropLast(false); + + assertArrayEquals(new String[] {"test_input"}, model.getInputCols()); + assertArrayEquals(new String[] {"test_output"}, model.getOutputCols()); + assertFalse(model.getDropLast()); + } + + @Test + public void testFitAndPredict() { + OneHotEncoderModel model = estimator.fit(trainTable); + Table outputTable = model.transform(predictTable)[0]; + Map<Double, Vector>[] actualOutput = + executeAndCollect(outputTable, model.getInputCols(), model.getOutputCols()); + assertArrayEquals(expectedOutput, actualOutput); + } + + @Test + public void testDropLast() { + 
estimator.setDropLast(false); + + expectedOutput = + new HashMap[] { + new HashMap<Double, Vector>() { + { + put(0.0, Vectors.sparse(3, new int[] {0}, new double[] {1.0})); + put(1.0, Vectors.sparse(3, new int[] {1}, new double[] {1.0})); + put(2.0, Vectors.sparse(3, new int[] {2}, new double[] {1.0})); + } + } + }; + + OneHotEncoderModel model = estimator.fit(trainTable); + Table outputTable = model.transform(predictTable)[0]; + Map<Double, Vector>[] actualOutput = + executeAndCollect(outputTable, model.getInputCols(), model.getOutputCols()); + assertArrayEquals(expectedOutput, actualOutput); + } + + @Test + public void testInputDataType() { + List<Row> trainData = Arrays.asList(Row.of(0), Row.of(1), Row.of(2), Row.of(0)); + + trainTable = tEnv.fromDataStream(env.fromCollection(trainData)).as("input"); + + List<Row> predictData = Arrays.asList(Row.of(0), Row.of(1), Row.of(2)); + predictTable = tEnv.fromDataStream(env.fromCollection(predictData)).as("input"); + + expectedOutput = + new HashMap[] { + new HashMap<Double, Vector>() { + { + put(0.0, Vectors.sparse(2, new int[] {0}, new double[] {1.0})); + put(1.0, Vectors.sparse(2, new int[] {1}, new double[] {1.0})); + put(2.0, Vectors.sparse(2, new int[0], new double[0])); + } + } + }; + + OneHotEncoderModel model = estimator.fit(trainTable); + Table outputTable = model.transform(predictTable)[0]; + Map<Double, Vector>[] actualOutput = + executeAndCollect(outputTable, model.getInputCols(), model.getOutputCols()); + assertArrayEquals(expectedOutput, actualOutput); + } + + @Test + public void testNotSupportedHandleInvalidOptions() { + estimator.setHandleInvalid(HasHandleInvalid.SKIP_INVALID); + try { + estimator.fit(trainTable); + Assert.fail("Expected IllegalArgumentException"); + } catch (Exception e) { + assertEquals(IllegalArgumentException.class, ((Throwable) e).getClass()); + } + } + + @Test + public void testNonIndexedTrainData() { + List<Row> trainData = Arrays.asList(Row.of(0.5), Row.of(1.0), Row.of(2.0), 
Row.of(0.0)); + + trainTable = tEnv.fromDataStream(env.fromCollection(trainData)).as("input"); + OneHotEncoderModel model = estimator.fit(trainTable); + Table outputTable = model.transform(predictTable)[0]; + try { + outputTable.execute().collect().next(); + Assert.fail("Expected IllegalArgumentException"); + } catch (Exception e) { + Throwable exception = e; + while (exception.getCause() != null) { + exception = exception.getCause(); + } + assertEquals(IllegalArgumentException.class, exception.getClass()); + assertEquals("Value 0.5 cannot be parsed as indexed integer.", exception.getMessage()); + } + } + + @Test + public void testNonIndexedPredictData() { + List<Row> predictData = Arrays.asList(Row.of(0.5), Row.of(1.0), Row.of(2.0), Row.of(0.0)); + + predictTable = tEnv.fromDataStream(env.fromCollection(predictData)).as("input"); + OneHotEncoderModel model = estimator.fit(trainTable); + Table outputTable = model.transform(predictTable)[0]; + try { + outputTable.execute().collect().next(); + Assert.fail("Expected IllegalArgumentException"); + } catch (Exception e) { + Throwable exception = e; + while (exception.getCause() != null) { + exception = exception.getCause(); + } + assertEquals(IllegalArgumentException.class, exception.getClass()); + assertEquals("Value 0.5 cannot be parsed as indexed integer.", exception.getMessage()); + } + } + + @Test + public void testSaveLoad() throws Exception { + estimator = + StageTestUtils.saveAndReload( + env, estimator, tempFolder.newFolder().getAbsolutePath()); + OneHotEncoderModel model = estimator.fit(trainTable); + model = StageTestUtils.saveAndReload(env, model, tempFolder.newFolder().getAbsolutePath()); + Table outputTable = model.transform(predictTable)[0]; + Map<Double, Vector>[] actualOutput = + executeAndCollect(outputTable, model.getInputCols(), model.getOutputCols()); + assertArrayEquals(expectedOutput, actualOutput); + } + + @Test + public void testGetModelData() throws Exception { + OneHotEncoderModel model = 
estimator.fit(trainTable); + Tuple2<Integer, Integer> expected = new Tuple2<>(0, 2); + Tuple2<Integer, Integer> actual = + OneHotEncoderModelData.getModelDataStream(model.getModelData()[0]) + .executeAndCollect() + .next(); + assertEquals(expected, actual); + } + + @Test + public void testSetModelData() { + OneHotEncoderModel modelA = estimator.fit(trainTable); + + Table modelData = modelA.getModelData()[0]; + OneHotEncoderModel modelB = new OneHotEncoderModel().setModelData(modelData); + ReadWriteUtils.updateExistingParams(modelB, modelA.getParamMap()); + + Table outputTable = modelB.transform(predictTable)[0]; + Map<Double, Vector>[] actualOutput = + executeAndCollect(outputTable, modelB.getInputCols(), modelB.getOutputCols()); + assertArrayEquals(expectedOutput, actualOutput); + } +}