This is an automated email from the ASF dual-hosted git repository.

gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new 79d0c24  [FLINK-24955] Add Estimator and Transformer for One Hot Encoder
79d0c24 is described below

commit 79d0c247e0cc900c9a16e045a99000423003780b
Author: Yunfeng Zhou <yuri.zhouyunf...@outlook.com>
AuthorDate: Thu Dec 2 15:32:17 2021 +0800

    [FLINK-24955] Add Estimator and Transformer for One Hot Encoder
    
    This closes #37.
---
 .../org/apache/flink/ml/linalg/SparseVector.java   | 168 ++++++++++++
 .../java/org/apache/flink/ml/linalg/Vectors.java   |   5 +
 .../ml/linalg/typeinfo/DenseVectorSerializer.java  |   8 +-
 .../ml/linalg/typeinfo/DenseVectorTypeInfo.java    |   2 +-
 .../ml/linalg/typeinfo/SparseVectorSerializer.java | 151 +++++++++++
 ...ctorTypeInfo.java => SparseVectorTypeInfo.java} |  66 +++--
 .../typeinfo/SparseVectorTypeInfoFactory.java      |  40 +++
 .../org/apache/flink/ml/param/ParamValidators.java |   5 +
 .../java/org/apache/flink/ml/api/StageTest.java    |   5 +
 .../apache/flink/ml/linalg/SparseVectorTest.java   | 132 +++++++++
 .../flink/ml/common/param/HasHandleInvalid.java    |  56 ++++
 .../apache/flink/ml/common/param/HasInputCols.java |  23 +-
 .../flink/ml/common/param/HasOutputCols.java       |  23 +-
 .../ml/feature/onehotencoder/OneHotEncoder.java    | 148 +++++++++++
 .../feature/onehotencoder/OneHotEncoderModel.java  | 190 +++++++++++++
 .../onehotencoder/OneHotEncoderModelData.java      | 109 ++++++++
 .../feature/onehotencoder/OneHotEncoderParams.java |  28 +-
 .../apache/flink/ml/feature/OneHotEncoderTest.java | 294 +++++++++++++++++++++
 18 files changed, 1393 insertions(+), 60 deletions(-)

diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/SparseVector.java b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/SparseVector.java
new file mode 100644
index 0000000..4e683a4
--- /dev/null
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/SparseVector.java
@@ -0,0 +1,168 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.linalg;
+
+import org.apache.flink.api.common.typeinfo.TypeInfo;
+import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfoFactory;
+import org.apache.flink.util.Preconditions;
+
+import java.util.Arrays;
+import java.util.Objects;
+
+/** A sparse vector of double values. */
+@TypeInfo(SparseVectorTypeInfoFactory.class)
+public class SparseVector implements Vector {
+    public final int n;
+    public final int[] indices;
+    public final double[] values;
+
+    public SparseVector(int n, int[] indices, double[] values) {
+        this.n = n;
+        this.indices = indices;
+        this.values = values;
+        if (!isIndicesSorted()) {
+            sortIndices();
+        }
+        validateSortedData();
+    }
+
+    @Override
+    public int size() {
+        return n;
+    }
+
+    @Override
+    public double get(int i) {
+        int pos = Arrays.binarySearch(indices, i);
+        if (pos >= 0) {
+            return values[pos];
+        }
+        return 0.;
+    }
+
+    @Override
+    public double[] toArray() {
+        double[] result = new double[n];
+        for (int i = 0; i < indices.length; i++) {
+            result[indices[i]] = values[i];
+        }
+        return result;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+        SparseVector that = (SparseVector) o;
+        return n == that.n
+                && Arrays.equals(indices, that.indices)
+                && Arrays.equals(values, that.values);
+    }
+
+    @Override
+    public int hashCode() {
+        int result = Objects.hash(n);
+        result = 31 * result + Arrays.hashCode(indices);
+        result = 31 * result + Arrays.hashCode(values);
+        return result;
+    }
+
+    /**
+     * Checks whether the input data is valid.
+     *
+     * <p>This function does the following checks:
+     *
+     * <ul>
+     *   <li>The indices array and values array are of the same size.
+     *   <li>Vector indices are in the valid range.
+     *   <li>Vector indices are unique.
+     * </ul>
+     *
+     * <p>This function works as expected only when indices are sorted.
+     */
+    private void validateSortedData() {
+        Preconditions.checkArgument(
+                indices.length == values.length,
+                "Indices size and values size should be the same.");
+        if (this.indices.length > 0) {
+            Preconditions.checkArgument(
+                    this.indices[0] >= 0 && this.indices[this.indices.length - 1] < this.n,
+                    "Index out of bound.");
+        }
+        for (int i = 1; i < this.indices.length; i++) {
+            Preconditions.checkArgument(
+                    this.indices[i] > this.indices[i - 1], "Indices duplicated.");
+        }
+    }
+
+    private boolean isIndicesSorted() {
+        for (int i = 1; i < this.indices.length; i++) {
+            if (this.indices[i] < this.indices[i - 1]) {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    /** Sorts the indices and values. */
+    private void sortIndices() {
+        sortImpl(this.indices, this.values, 0, this.indices.length - 1);
+    }
+
+    /** Sorts the indices and values using quick sort. */
+    private static void sortImpl(int[] indices, double[] values, int low, int high) {
+        int pivotPos = (low + high) / 2;
+        int pivot = indices[pivotPos];
+        swapIndexAndValue(indices, values, pivotPos, high);
+
+        int pos = low - 1;
+        for (int i = low; i <= high; i++) {
+            if (indices[i] <= pivot) {
+                pos++;
+                swapIndexAndValue(indices, values, pos, i);
+            }
+        }
+        if (high > pos + 1) {
+            sortImpl(indices, values, pos + 1, high);
+        }
+        if (pos - 1 > low) {
+            sortImpl(indices, values, low, pos - 1);
+        }
+    }
+
+    private static void swapIndexAndValue(int[] indices, double[] values, int index1, int index2) {
+        int tempIndex = indices[index1];
+        indices[index1] = indices[index2];
+        indices[index2] = tempIndex;
+        double tempValue = values[index1];
+        values[index1] = values[index2];
+        values[index2] = tempValue;
+    }
+
+    @Override
+    public String toString() {
+        String sbr =
+                "(" + n + ", " + Arrays.toString(indices) + ", " + Arrays.toString(values) + ")";
+        return sbr;
+    }
+}
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java
index a058755..424b27f 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java
@@ -25,4 +25,9 @@ public class Vectors {
     public static DenseVector dense(double... values) {
         return new DenseVector(values);
     }
+
+    /** Creates a sparse vector from its values. */
+    public static SparseVector sparse(int size, int[] indices, double[] values) {
+        return new SparseVector(size, indices, values);
+    }
 }
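For quick orientation, a minimal sketch of how the new SparseVector and the Vectors.sparse factory above behave; the class and method names come from this patch, while the wrapper class is illustrative only:

    import org.apache.flink.ml.linalg.SparseVector;
    import org.apache.flink.ml.linalg.Vectors;

    public class SparseVectorSketch {
        public static void main(String[] args) {
            // Indices may be passed unsorted; the constructor sorts them together
            // with their values and then validates sizes, range, and uniqueness.
            SparseVector v = Vectors.sparse(4, new int[] {3, 2, 0}, new double[] {0.4, 0.3, 0.1});

            System.out.println(v.get(1)); // 0.0 -- index 1 is not stored
            System.out.println(v.get(2)); // 0.3 -- looked up via binary search
            System.out.println(v);        // (4, [0, 2, 3], [0.1, 0.3, 0.4])
        }
    }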
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializer.java b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializer.java
index 3cbde53..153a20f 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializer.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializer.java
@@ -29,14 +29,14 @@ import org.apache.flink.ml.linalg.DenseVector;
 import java.io.IOException;
 import java.util.Arrays;
 
-/** Specialized serializer for {@code DenseVector}. */
+/** Specialized serializer for {@link DenseVector}. */
 public final class DenseVectorSerializer extends TypeSerializerSingleton<DenseVector> {
 
     private static final long serialVersionUID = 1L;
 
     private static final double[] EMPTY = new double[0];
 
-    private static final DenseVectorSerializer INSTANCE = new DenseVectorSerializer();
+    public static final DenseVectorSerializer INSTANCE = new DenseVectorSerializer();
 
     @Override
     public boolean isImmutableType() {
@@ -84,9 +84,7 @@ public final class DenseVectorSerializer extends TypeSerializerSingleton<DenseVe
     public DenseVector deserialize(DataInputView source) throws IOException {
         int len = source.readInt();
         double[] values = new double[len];
-        for (int i = 0; i < len; i++) {
-            values[i] = source.readDouble();
-        }
+        readDoubleArray(values, source, len);
         return new DenseVector(values);
     }
 
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorTypeInfo.java b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorTypeInfo.java
index 0239e17..765cacb 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorTypeInfo.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorTypeInfo.java
@@ -64,7 +64,7 @@ public class DenseVectorTypeInfo extends TypeInformation<DenseVector> {
     @Override
     @SuppressWarnings("unchecked")
     public TypeSerializer<DenseVector> createSerializer(ExecutionConfig executionConfig) {
-        return new DenseVectorSerializer();
+        return DenseVectorSerializer.INSTANCE;
     }
 
     // --------------------------------------------------------------------------------------------
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorSerializer.java b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorSerializer.java
new file mode 100644
index 0000000..2c922a9
--- /dev/null
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorSerializer.java
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.flink.ml.linalg.typeinfo;
+
+import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.ml.linalg.SparseVector;
+
+import java.io.IOException;
+import java.util.Arrays;
+
+/** Specialized serializer for {@link SparseVector}. */
+public final class SparseVectorSerializer extends TypeSerializerSingleton<SparseVector> {
+
+    private static final long serialVersionUID = 1L;
+
+    private static final double[] EMPTY_DOUBLE_ARRAY = new double[0];
+
+    private static final int[] EMPTY_INT_ARRAY = new int[0];
+
+    public static final SparseVectorSerializer INSTANCE = new SparseVectorSerializer();
+
+    @Override
+    public boolean isImmutableType() {
+        return false;
+    }
+
+    @Override
+    public SparseVector createInstance() {
+        return new SparseVector(0, EMPTY_INT_ARRAY, EMPTY_DOUBLE_ARRAY);
+    }
+
+    @Override
+    public SparseVector copy(SparseVector from) {
+        return new SparseVector(
+                from.n,
+                Arrays.copyOf(from.indices, from.indices.length),
+                Arrays.copyOf(from.values, from.values.length));
+    }
+
+    @Override
+    public SparseVector copy(SparseVector from, SparseVector reuse) {
+        if (from.values.length == reuse.values.length && from.n == reuse.n) {
+            System.arraycopy(from.values, 0, reuse.values, 0, from.values.length);
+            System.arraycopy(from.indices, 0, reuse.indices, 0, from.indices.length);
+            return reuse;
+        }
+        return copy(from);
+    }
+
+    @Override
+    public int getLength() {
+        return -1;
+    }
+
+    @Override
+    public void serialize(SparseVector vector, DataOutputView target) throws IOException {
+        if (vector == null) {
+            throw new IllegalArgumentException("The vector must not be null.");
+        }
+
+        target.writeInt(vector.n);
+        final int len = vector.values.length;
+        target.writeInt(len);
+        for (int i = 0; i < len; i++) {
+            target.writeInt(vector.indices[i]);
+            target.writeDouble(vector.values[i]);
+        }
+    }
+
+    // Reads `len` int values from `source` into `indices` and `len` double values from `source`
+    // into `values`.
+    private void readSparseVectorArrays(
+            int[] indices, double[] values, DataInputView source, int len) throws IOException {
+        for (int i = 0; i < len; i++) {
+            indices[i] = source.readInt();
+            values[i] = source.readDouble();
+        }
+    }
+
+    @Override
+    public SparseVector deserialize(DataInputView source) throws IOException {
+        int n = source.readInt();
+        int len = source.readInt();
+        int[] indices = new int[len];
+        double[] values = new double[len];
+        readSparseVectorArrays(indices, values, source, len);
+        return new SparseVector(n, indices, values);
+    }
+
+    @Override
+    public SparseVector deserialize(SparseVector reuse, DataInputView source) throws IOException {
+        int n = source.readInt();
+        int len = source.readInt();
+        if (reuse.n == n && reuse.values.length == len) {
+            readSparseVectorArrays(reuse.indices, reuse.values, source, len);
+            return reuse;
+        }
+
+        int[] indices = new int[len];
+        double[] values = new double[len];
+        readSparseVectorArrays(indices, values, source, len);
+        return new SparseVector(n, indices, values);
+    }
+
+    @Override
+    public void copy(DataInputView source, DataOutputView target) throws IOException {
+        int n = source.readInt();
+        int len = source.readInt();
+
+        target.writeInt(n);
+        target.writeInt(len);
+
+        target.write(source, len * 12);
+    }
+
+    @Override
+    public TypeSerializerSnapshot<SparseVector> snapshotConfiguration() {
+        return new SparseVectorSerializerSnapshot();
+    }
+
+    /** Serializer configuration snapshot for compatibility and format evolution. */
+    @SuppressWarnings("WeakerAccess")
+    public static final class SparseVectorSerializerSnapshot
+            extends SimpleTypeSerializerSnapshot<SparseVector> {
+
+        public SparseVectorSerializerSnapshot() {
+            super(() -> INSTANCE);
+        }
+    }
+}
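As implied by serialize() and relied on by the `len * 12` stride in copy() above, the wire format is 4 bytes for the vector size n, 4 bytes for the entry count, then one int index plus one double value per stored entry. A small helper illustrating the arithmetic (serializedSize is not part of the patch):

    final class SparseVectorWireFormat {
        // Bytes produced by SparseVectorSerializer.serialize() for `len` stored entries:
        // vector size (int) + entry count (int) + len * (index int + value double).
        static int serializedSize(int len) {
            return Integer.BYTES + Integer.BYTES + len * (Integer.BYTES + Double.BYTES); // 8 + 12 * len
        }
    }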
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorTypeInfo.java b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorTypeInfo.java
similarity index 50%
copy from flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorTypeInfo.java
copy to flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorTypeInfo.java
index 0239e17..06686f0 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorTypeInfo.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorTypeInfo.java
@@ -7,13 +7,14 @@
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
  *
- *     http://www.apache.org/licenses/LICENSE-2.0
+ *   http://www.apache.org/licenses/LICENSE-2.0
  *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
  */
 
 package org.apache.flink.ml.linalg.typeinfo;
@@ -21,39 +22,35 @@ package org.apache.flink.ml.linalg.typeinfo;
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
 
-/** A {@link TypeInformation} for the {@link DenseVector} type. */
-public class DenseVectorTypeInfo extends TypeInformation<DenseVector> {
-    private static final long serialVersionUID = 1L;
-
-    public static final DenseVectorTypeInfo INSTANCE = new DenseVectorTypeInfo();
-
-    public DenseVectorTypeInfo() {}
+/** A {@link TypeInformation} for the {@link SparseVector} type. */
+public class SparseVectorTypeInfo extends TypeInformation<SparseVector> {
+    public static final SparseVectorTypeInfo INSTANCE = new SparseVectorTypeInfo();
 
     @Override
-    public int getArity() {
-        return 1;
+    public boolean isBasicType() {
+        return false;
     }
 
     @Override
-    public int getTotalFields() {
-        return 1;
+    public boolean isTupleType() {
+        return false;
     }
 
     @Override
-    public Class<DenseVector> getTypeClass() {
-        return DenseVector.class;
+    public int getArity() {
+        return 3;
     }
 
     @Override
-    public boolean isBasicType() {
-        return false;
+    public int getTotalFields() {
+        return 3;
     }
 
     @Override
-    public boolean isTupleType() {
-        return false;
+    public Class<SparseVector> getTypeClass() {
+        return SparseVector.class;
     }
 
     @Override
@@ -62,30 +59,27 @@ public class DenseVectorTypeInfo extends TypeInformation<DenseVector> {
     }
 
     @Override
-    @SuppressWarnings("unchecked")
-    public TypeSerializer<DenseVector> createSerializer(ExecutionConfig executionConfig) {
-        return new DenseVectorSerializer();
+    public TypeSerializer<SparseVector> createSerializer(ExecutionConfig executionConfig) {
+        return SparseVectorSerializer.INSTANCE;
     }
 
-    // --------------------------------------------------------------------------------------------
-
     @Override
-    public int hashCode() {
-        return getClass().hashCode();
+    public String toString() {
+        return "SparseVectorType";
     }
 
     @Override
     public boolean equals(Object obj) {
-        return obj instanceof DenseVectorTypeInfo;
+        return obj instanceof SparseVectorTypeInfo;
     }
 
     @Override
-    public boolean canEqual(Object obj) {
-        return obj instanceof DenseVectorTypeInfo;
+    public int hashCode() {
+        return getClass().hashCode();
     }
 
     @Override
-    public String toString() {
-        return "DenseVectorType";
+    public boolean canEqual(Object obj) {
+        return obj instanceof SparseVectorTypeInfo;
     }
 }
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorTypeInfoFactory.java b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorTypeInfoFactory.java
new file mode 100644
index 0000000..01c1036
--- /dev/null
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorTypeInfoFactory.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.flink.ml.linalg.typeinfo;
+
+import org.apache.flink.api.common.typeinfo.TypeInfoFactory;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.TypeExtractor;
+import org.apache.flink.ml.linalg.SparseVector;
+
+import java.lang.reflect.Type;
+import java.util.Map;
+
+/**
+ * Used by {@link TypeExtractor} to create a {@link TypeInformation} for implementations of {@link
+ * SparseVector}.
+ */
+public class SparseVectorTypeInfoFactory extends TypeInfoFactory<SparseVector> {
+    @Override
+    public TypeInformation<SparseVector> createTypeInfo(
+            Type type, Map<String, TypeInformation<?>> map) {
+        return new SparseVectorTypeInfo();
+    }
+}
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/param/ParamValidators.java b/flink-ml-core/src/main/java/org/apache/flink/ml/param/ParamValidators.java
index 925ccb2..e7d1436 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/param/ParamValidators.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/param/ParamValidators.java
@@ -95,4 +95,9 @@ public class ParamValidators {
             }
         };
     }
+
+    // Checks that the parameter value is a non-empty array.
+    public static <T> ParamValidator<T[]> nonEmptyArray() {
+        return value -> value != null && value.length > 0;
+    }
 }
diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/api/StageTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/api/StageTest.java
index 6ac630d..df0db64 100644
--- a/flink-ml-core/src/test/java/org/apache/flink/ml/api/StageTest.java
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/api/StageTest.java
@@ -388,5 +388,10 @@ public class StageTest {
         ParamValidator<Integer> notNull = ParamValidators.notNull();
         Assert.assertTrue(notNull.validate(5));
         Assert.assertFalse(notNull.validate(null));
+
+        ParamValidator<Object[]> nonEmptyArray = ParamValidators.nonEmptyArray();
+        Assert.assertTrue(nonEmptyArray.validate(new String[] {"1"}));
+        Assert.assertFalse(nonEmptyArray.validate(null));
+        Assert.assertFalse(nonEmptyArray.validate(new String[0]));
     }
 }
diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/SparseVectorTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/SparseVectorTest.java
new file mode 100644
index 0000000..0e7c349
--- /dev/null
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/SparseVectorTest.java
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.linalg;
+
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.linalg.typeinfo.SparseVectorSerializer;
+
+import org.apache.commons.io.output.ByteArrayOutputStream;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+/** Tests the behavior of {@link SparseVector}. */
+public class SparseVectorTest {
+    @Test
+    public void testConstructor() {
+        int n = 4;
+        int[] indices = new int[] {0, 2, 3};
+        double[] values = new double[] {0.1, 0.3, 0.4};
+
+        SparseVector vector = Vectors.sparse(n, indices, values);
+        assertEquals(n, vector.n);
+        assertArrayEquals(indices, vector.indices);
+        assertArrayEquals(values, vector.values, 1e-5);
+        assertEquals("(4, [0, 2, 3], [0.1, 0.3, 0.4])", vector.toString());
+    }
+
+    @Test
+    public void testDuplicateIndex() {
+        int n = 4;
+        int[] indices = new int[] {0, 2, 2};
+        double[] values = new double[] {0.1, 0.3, 0.4};
+
+        try {
+            Vectors.sparse(n, indices, values);
+            Assert.fail("Expected IllegalArgumentException.");
+        } catch (Exception e) {
+            assertEquals(IllegalArgumentException.class, e.getClass());
+            assertEquals("Indices duplicated.", e.getMessage());
+        }
+    }
+
+    @Test
+    public void testAllZeroVector() {
+        int n = 4;
+        SparseVector vector = Vectors.sparse(n, new int[0], new double[0]);
+        assertArrayEquals(vector.toArray(), new double[n], 1e-5);
+    }
+
+    @Test
+    public void testUnsortedIndex() {
+        SparseVector vector;
+
+        vector = Vectors.sparse(4, new int[] {2}, new double[] {0.3});
+        assertEquals(4, vector.n);
+        assertArrayEquals(new int[] {2}, vector.indices);
+        assertArrayEquals(new double[] {0.3}, vector.values, 1e-5);
+
+        vector = Vectors.sparse(4, new int[] {1, 2}, new double[] {0.2, 0.3});
+        assertEquals(4, vector.n);
+        assertArrayEquals(new int[] {1, 2}, vector.indices);
+        assertArrayEquals(new double[] {0.2, 0.3}, vector.values, 1e-5);
+
+        vector = Vectors.sparse(4, new int[] {2, 1}, new double[] {0.3, 0.2});
+        assertEquals(4, vector.n);
+        assertArrayEquals(new int[] {1, 2}, vector.indices);
+        assertArrayEquals(new double[] {0.2, 0.3}, vector.values, 1e-5);
+
+        vector = Vectors.sparse(4, new int[] {3, 2, 0}, new double[] {0.4, 0.3, 0.1});
+        assertEquals(4, vector.n);
+        assertArrayEquals(new int[] {0, 2, 3}, vector.indices);
+        assertArrayEquals(new double[] {0.1, 0.3, 0.4}, vector.values, 1e-5);
+
+        vector = Vectors.sparse(4, new int[] {2, 0, 3}, new double[] {0.3, 0.1, 0.4});
+        assertEquals(4, vector.n);
+        assertArrayEquals(new int[] {0, 2, 3}, vector.indices);
+        assertArrayEquals(new double[] {0.1, 0.3, 0.4}, vector.values, 1e-5);
+
+        vector =
+                Vectors.sparse(
+                        7,
+                        new int[] {6, 5, 4, 3, 2, 1, 0},
+                        new double[] {0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1});
+        assertEquals(7, vector.n);
+        assertArrayEquals(new int[] {0, 1, 2, 3, 4, 5, 6}, vector.indices);
+        assertArrayEquals(new double[] {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7}, vector.values, 1e-5);
+    }
+
+    @Test
+    public void testSerializer() throws IOException {
+        int n = 4;
+        int[] indices = new int[] {0, 2, 3};
+        double[] values = new double[] {0.1, 0.3, 0.4};
+        SparseVector vector = Vectors.sparse(n, indices, values);
+        SparseVectorSerializer serializer = SparseVectorSerializer.INSTANCE;
+
+        ByteArrayOutputStream bOutput = new ByteArrayOutputStream(1024);
+        DataOutputViewStreamWrapper output = new DataOutputViewStreamWrapper(bOutput);
+        serializer.serialize(vector, output);
+
+        byte[] b = bOutput.toByteArray();
+        ByteArrayInputStream bInput = new ByteArrayInputStream(b);
+        DataInputViewStreamWrapper input = new DataInputViewStreamWrapper(bInput);
+        SparseVector vector2 = serializer.deserialize(input);
+
+        assertEquals(vector.n, vector2.n);
+        assertArrayEquals(vector.indices, vector2.indices);
+        assertArrayEquals(vector.values, vector2.values, 1e-5);
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasHandleInvalid.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasHandleInvalid.java
new file mode 100644
index 0000000..a7ea41a
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasHandleInvalid.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+import org.apache.flink.ml.param.WithParams;
+
+/**
+ * Interface for the shared handleInvalid param.
+ *
+ * <p>Supported options and the corresponding behavior for handling invalid entries are listed as
+ * follows.
+ *
+ * <ul>
+ *   <li>error: raise an exception.
+ *   <li>skip: filter out rows with bad values.
+ * </ul>
+ */
+public interface HasHandleInvalid<T> extends WithParams<T> {
+    String ERROR_INVALID = "error";
+    String SKIP_INVALID = "skip";
+
+    Param<String> HANDLE_INVALID =
+            new StringParam(
+                    "handleInvalid",
+                    "Strategy to handle invalid entries.",
+                    ERROR_INVALID,
+                    ParamValidators.inArray(ERROR_INVALID, SKIP_INVALID));
+
+    default String getHandleInvalid() {
+        return get(HANDLE_INVALID);
+    }
+
+    default T setHandleInvalid(String value) {
+        set(HANDLE_INVALID, value);
+        return (T) this;
+    }
+}
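A brief usage sketch of the shared param, taking the OneHotEncoder added later in this commit as the stage that mixes it in; note that OneHotEncoder.fit() currently accepts only the default strategy:

    OneHotEncoder encoder = new OneHotEncoder();
    String strategy = encoder.getHandleInvalid();             // "error", the default
    encoder.setHandleInvalid(HasHandleInvalid.SKIP_INVALID);  // accepted by the param validator,
                                                              // but rejected later by fit()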
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasInputCols.java
similarity index 55%
copy from flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java
copy to flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasInputCols.java
index a058755..c567de7 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasInputCols.java
@@ -16,13 +16,24 @@
  * limitations under the License.
  */
 
-package org.apache.flink.ml.linalg;
+package org.apache.flink.ml.common.param;
 
-/** Utility methods for instantiating Vector. */
-public class Vectors {
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringArrayParam;
+import org.apache.flink.ml.param.WithParams;
 
-    /** Creates a dense vector from its values. */
-    public static DenseVector dense(double... values) {
-        return new DenseVector(values);
+/** Interface for the shared inputCols param. */
+public interface HasInputCols<T> extends WithParams<T> {
+    Param<String[]> INPUT_COLS =
+            new StringArrayParam(
+                    "inputCols", "Input column names.", null, ParamValidators.nonEmptyArray());
+
+    default String[] getInputCols() {
+        return get(INPUT_COLS);
+    }
+
+    default T setInputCols(String... value) {
+        return set(INPUT_COLS, value);
     }
 }
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasOutputCols.java
similarity index 54%
copy from flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java
copy to flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasOutputCols.java
index a058755..947501f 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasOutputCols.java
@@ -16,13 +16,24 @@
  * limitations under the License.
  */
 
-package org.apache.flink.ml.linalg;
+package org.apache.flink.ml.common.param;
 
-/** Utility methods for instantiating Vector. */
-public class Vectors {
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringArrayParam;
+import org.apache.flink.ml.param.WithParams;
 
-    /** Creates a dense vector from its values. */
-    public static DenseVector dense(double... values) {
-        return new DenseVector(values);
+/** Interface for the shared outputCols param. */
+public interface HasOutputCols<T> extends WithParams<T> {
+    Param<String[]> OUTPUT_COLS =
+            new StringArrayParam(
+                    "outputCols", "Output column names.", null, ParamValidators.nonEmptyArray());
+
+    default String[] getOutputCols() {
+        return get(OUTPUT_COLS);
+    }
+
+    default T setOutputCols(String... value) {
+        return set(OUTPUT_COLS, value);
     }
 }
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoder.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoder.java
new file mode 100644
index 0000000..374d457
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoder.java
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.onehotencoder;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.MapPartitionFunctionWrapper;
+import org.apache.flink.ml.common.param.HasHandleInvalid;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the one-hot encoding algorithm.
+ *
+ * <p>Data in the selected input columns should be indexed numbers in order for OneHotEncoder to
+ * function correctly.
+ *
+ * <p>See https://en.wikipedia.org/wiki/One-hot.
+ */
+public class OneHotEncoder
+        implements Estimator<OneHotEncoder, OneHotEncoderModel>,
+                OneHotEncoderParams<OneHotEncoder> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public OneHotEncoder() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OneHotEncoderModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        Preconditions.checkArgument(getHandleInvalid().equals(HasHandleInvalid.ERROR_INVALID));
+
+        final String[] inputCols = getInputCols();
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Tuple2<Integer, Integer>> modelData =
+                tEnv.toDataStream(inputs[0])
+                        .flatMap(new ExtractInputColsValueFunction(inputCols))
+                        .keyBy(columnIdAndValue -> columnIdAndValue.f0)
+                        .transform(
+                                "findMaxIndex",
+                                Types.TUPLE(Types.INT, Types.INT),
+                                new MapPartitionFunctionWrapper<>(new FindMaxIndexFunction()));
+
+        OneHotEncoderModel model =
+                new OneHotEncoderModel().setModelData(tEnv.fromDataStream(modelData));
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OneHotEncoder load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    /**
+     * Extracts the values of the designated input columns from the input data.
+     *
+     * <p>Input: rows of input data containing the designated input columns.
+     *
+     * <p>Output: pairs of column index and the value stored in that column.
+     */
+    private static class ExtractInputColsValueFunction
+            implements FlatMapFunction<Row, Tuple2<Integer, Integer>> {
+        private final String[] inputCols;
+
+        private ExtractInputColsValueFunction(String[] inputCols) {
+            this.inputCols = inputCols;
+        }
+
+        @Override
+        public void flatMap(Row row, Collector<Tuple2<Integer, Integer>> collector) {
+            for (int i = 0; i < inputCols.length; i++) {
+                Number number = (Number) row.getField(inputCols[i]);
+                Preconditions.checkArgument(
+                        number.intValue() == number.doubleValue(),
+                        String.format("Value %s cannot be parsed as indexed integer.", number));
+                Preconditions.checkArgument(
+                        number.intValue() >= 0, "Negative value not supported.");
+                collector.collect(new Tuple2<>(i, number.intValue()));
+            }
+        }
+    }
+
+    /** Function to find the max index value for each column. */
+    private static class FindMaxIndexFunction
+            implements MapPartitionFunction<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>> {
+
+        @Override
+        public void mapPartition(
+                Iterable<Tuple2<Integer, Integer>> iterable,
+                Collector<Tuple2<Integer, Integer>> collector) {
+            Map<Integer, Integer> map = new HashMap<>();
+            for (Tuple2<Integer, Integer> value : iterable) {
+                map.put(
+                        value.f0,
+                        Math.max(map.getOrDefault(value.f0, Integer.MIN_VALUE), value.f1));
+            }
+            for (Map.Entry<Integer, Integer> entry : map.entrySet()) {
+                collector.collect(new Tuple2<>(entry.getKey(), entry.getValue()));
+            }
+        }
+    }
+}
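A minimal fit/transform sketch for the estimator above; trainTable is a hypothetical Table whose "color" column holds non-negative integer indices, and both column names are illustrative:

    OneHotEncoder encoder = new OneHotEncoder()
            .setInputCols("color")       // one or more indexed input columns
            .setOutputCols("colorVec");  // one output column per input column

    OneHotEncoderModel model = encoder.fit(trainTable);
    Table output = model.transform(trainTable)[0];  // appends "colorVec" holding sparse vectors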
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoderModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoderModel.java
new file mode 100644
index 0000000..447fe77
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoderModel.java
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.onehotencoder;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.param.HasHandleInvalid;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.runtime.typeutils.ExternalTypeInfo;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+/**
+ * A Model which encodes data into one-hot format using the model data computed by {@link
+ * OneHotEncoder}.
+ */
+public class OneHotEncoderModel
+        implements Model<OneHotEncoderModel>, OneHotEncoderParams<OneHotEncoderModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public OneHotEncoderModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        final String[] inputCols = getInputCols();
+        final String[] outputCols = getOutputCols();
+        final boolean dropLast = getDropLast();
+        final String broadcastModelKey = "OneHotModelStream";
+
+        Preconditions.checkArgument(getHandleInvalid().equals(HasHandleInvalid.ERROR_INVALID));
+        Preconditions.checkArgument(inputs.length == 1);
+        Preconditions.checkArgument(inputCols.length == outputCols.length);
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(),
+                                Collections.nCopies(
+                                                outputCols.length,
+                                                ExternalTypeInfo.of(Vector.class))
+                                        .toArray(new TypeInformation[0])),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), outputCols));
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) modelDataTable).getTableEnvironment();
+        DataStream<Row> input = tEnv.toDataStream(inputs[0]);
+        DataStream<Tuple2<Integer, Integer>> modelStream =
+                OneHotEncoderModelData.getModelDataStream(modelDataTable);
+
+        Function<List<DataStream<?>>, DataStream<Row>> function =
+                dataStreams -> {
+                    DataStream stream = dataStreams.get(0);
+                    return stream.map(
+                            new GenerateOutputsFunction(inputCols, dropLast, broadcastModelKey),
+                            outputTypeInfo);
+                };
+
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(input),
+                        Collections.singletonMap(broadcastModelKey, modelStream),
+                        function);
+
+        Table outputTable = tEnv.fromDataStream(output);
+
+        return new Table[] {outputTable};
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                OneHotEncoderModelData.getModelDataStream(modelDataTable),
+                path,
+                new OneHotEncoderModelData.ModelDataEncoder());
+    }
+
+    public static OneHotEncoderModel load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+        OneHotEncoderModel model = ReadWriteUtils.loadStageParam(path);
+        DataStream<Tuple2<Integer, Integer>> modelData =
+                ReadWriteUtils.loadModelData(
+                        env, path, new OneHotEncoderModelData.ModelDataStreamFormat());
+        return model.setModelData(tEnv.fromDataStream(modelData));
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public OneHotEncoderModel setModelData(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    private static class GenerateOutputsFunction extends RichMapFunction<Row, Row> {
+        private final String[] inputCols;
+        private final boolean dropLast;
+        private final String broadcastModelKey;
+        private List<Tuple2<Integer, Integer>> model = null;
+
+        public GenerateOutputsFunction(
+                String[] inputCols, boolean dropLast, String broadcastModelKey) {
+            this.inputCols = inputCols;
+            this.dropLast = dropLast;
+            this.broadcastModelKey = broadcastModelKey;
+        }
+
+        @Override
+        public Row map(Row row) {
+            if (model == null) {
+                model = getRuntimeContext().getBroadcastVariable(broadcastModelKey);
+            }
+            int[] categorySizes = new int[model.size()];
+            int offset = dropLast ? 0 : 1;
+            for (Tuple2<Integer, Integer> tup : model) {
+                categorySizes[tup.f0] = tup.f1 + offset;
+            }
+            Row result = new Row(categorySizes.length);
+            for (int i = 0; i < categorySizes.length; i++) {
+                Number number = (Number) row.getField(inputCols[i]);
+                Preconditions.checkArgument(
+                        number.intValue() == number.doubleValue(),
+                        String.format("Value %s cannot be parsed as indexed integer.", number));
+                int idx = number.intValue();
+                if (idx == categorySizes[i]) {
+                    result.setField(i, Vectors.sparse(categorySizes[i], new int[0], new double[0]));
+                } else {
+                    result.setField(
+                            i,
+                            Vectors.sparse(categorySizes[i], new int[] {idx}, new double[] {1.0}));
+                }
+            }
+
+            return Row.join(row, result);
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoderModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoderModelData.java
new file mode 100644
index 0000000..f267784
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoderModelData.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.onehotencoder;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+
+import java.io.IOException;
+import java.io.OutputStream;
+
+/**
+ * Model data of {@link OneHotEncoderModel}.
+ *
+ * <p>This class also provides methods to convert model data from Table to DataStream, and classes
+ * to save/load model data.
+ */
+public class OneHotEncoderModelData {
+    /**
+     * Converts the model data from a Table to a DataStream.
+     *
+     * @param modelData The model data stored as a Table.
+     * @return The model data as a DataStream.
+     */
+    public static DataStream<Tuple2<Integer, Integer>> getModelDataStream(Table modelData) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) modelData).getTableEnvironment();
+        return tEnv.toDataStream(modelData)
+                .map(
+                        new MapFunction<Row, Tuple2<Integer, Integer>>() {
+                            @Override
+                            public Tuple2<Integer, Integer> map(Row row) {
+                                return new Tuple2<>(
+                                        (int) row.getField("f0"), (int) row.getField("f1"));
+                            }
+                        });
+    }
+
+    /** Data encoder for the OneHotEncoder model data. */
+    public static class ModelDataEncoder implements Encoder<Tuple2<Integer, Integer>> {
+        @Override
+        public void encode(Tuple2<Integer, Integer> modelData, OutputStream outputStream) {
+            Output output = new Output(outputStream);
+            output.writeInt(modelData.f0);
+            output.writeInt(modelData.f1);
+            output.flush();
+        }
+    }
+
+    /** Data decoder for the OneHotEncoder model data. */
+    public static class ModelDataStreamFormat extends SimpleStreamFormat<Tuple2<Integer, Integer>> {
+        @Override
+        public Reader<Tuple2<Integer, Integer>> createReader(
+                Configuration config, FSDataInputStream stream) {
+            return new Reader<Tuple2<Integer, Integer>>() {
+                private final Input input = new Input(stream);
+
+                @Override
+                public Tuple2<Integer, Integer> read() {
+                    if (input.eof()) {
+                        return null;
+                    }
+                    int f0 = input.readInt();
+                    int f1 = input.readInt();
+                    return new Tuple2<>(f0, f1);
+                }
+
+                @Override
+                public void close() throws IOException {
+                    stream.close();
+                }
+            };
+        }
+
+        @Override
+        public TypeInformation<Tuple2<Integer, Integer>> getProducedType() {
+            return Types.TUPLE(Types.INT, Types.INT);
+        }
+    }
+}
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoderParams.java
similarity index 51%
copy from flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java
copy to flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoderParams.java
index a058755..9b57159 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoderParams.java
@@ -16,13 +16,29 @@
  * limitations under the License.
  */
 
-package org.apache.flink.ml.linalg;
+package org.apache.flink.ml.feature.onehotencoder;
 
-/** Utility methods for instantiating Vector. */
-public class Vectors {
+import org.apache.flink.ml.common.param.HasHandleInvalid;
+import org.apache.flink.ml.common.param.HasInputCols;
+import org.apache.flink.ml.common.param.HasOutputCols;
+import org.apache.flink.ml.param.BooleanParam;
+import org.apache.flink.ml.param.Param;
 
-    /** Creates a dense vector from its values. */
-    public static DenseVector dense(double... values) {
-        return new DenseVector(values);
+/**
+ * Params of OneHotEncoderModel.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface OneHotEncoderParams<T>
+        extends HasInputCols<T>, HasOutputCols<T>, HasHandleInvalid<T> {
+    Param<Boolean> DROP_LAST =
+            new BooleanParam("dropLast", "Whether to drop the last category.", true);
+
+    default boolean getDropLast() {
+        return get(DROP_LAST);
+    }
+
+    default T setDropLast(boolean value) {
+        return set(DROP_LAST, value);
     }
 }
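A minimal sketch of how the params above read from client code, assuming only the setters and getters defined in this patch; the ParamsSketch class name is illustrative. With dropLast left at its default of true, a column with n distinct training values is encoded into vectors of size n - 1 and the last category maps to the all-zero vector, as the tests below verify:

    import org.apache.flink.ml.feature.onehotencoder.OneHotEncoder;

    /** Illustrative use of the fluent params defined above; not part of the patch. */
    public class ParamsSketch {
        public static void main(String[] args) {
            OneHotEncoder encoder = new OneHotEncoder(); // dropLast defaults to true
            encoder.setInputCols("input").setOutputCols("output").setDropLast(false);
            System.out.println(encoder.getDropLast()); // prints false
        }
    }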
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OneHotEncoderTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OneHotEncoderTest.java
new file mode 100644
index 0000000..51f9735
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OneHotEncoderTest.java
@@ -0,0 +1,294 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.common.param.HasHandleInvalid;
+import org.apache.flink.ml.feature.onehotencoder.OneHotEncoder;
+import org.apache.flink.ml.feature.onehotencoder.OneHotEncoderModel;
+import org.apache.flink.ml.feature.onehotencoder.OneHotEncoderModelData;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+/** Tests OneHotEncoder and OneHotEncoderModel. */
+public class OneHotEncoderTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table trainTable;
+    private Table predictTable;
+    private Map<Double, Vector>[] expectedOutput;
+    private OneHotEncoder estimator;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        List<Row> trainData = Arrays.asList(Row.of(0.0), Row.of(1.0), Row.of(2.0), Row.of(0.0));
+
+        trainTable = tEnv.fromDataStream(env.fromCollection(trainData)).as("input");
+
+        List<Row> predictData = Arrays.asList(Row.of(0.0), Row.of(1.0), Row.of(2.0));
+
+        predictTable = tEnv.fromDataStream(env.fromCollection(predictData)).as("input");
+
+        expectedOutput =
+                new HashMap[] {
+                    new HashMap<Double, Vector>() {
+                        {
+                            put(0.0, Vectors.sparse(2, new int[] {0}, new double[] {1.0}));
+                            put(1.0, Vectors.sparse(2, new int[] {1}, new double[] {1.0}));
+                            put(2.0, Vectors.sparse(2, new int[0], new double[0]));
+                        }
+                    }
+                };
+
+        estimator = new OneHotEncoder().setInputCols("input").setOutputCols("output");
+    }
+
+    /**
+     * Executes a given table and collects its results into a map array. Each element in the array
+     * is a map corresponding to one input column, whose keys are the original values in that input
+     * column and whose values are the one-hot encoding results of those values.
+     *
+     * @param table A table to be executed and to have its results collected
+     * @param inputCols Names of the input columns
+     * @param outputCols Names of the output columns containing the one-hot encoding results
+     * @return An array of maps containing the collected results for each input column
+     */
+    private static Map<Double, Vector>[] executeAndCollect(
+            Table table, String[] inputCols, String[] outputCols) {
+        Map<Double, Vector>[] maps = new HashMap[inputCols.length];
+        for (int i = 0; i < inputCols.length; i++) {
+            maps[i] = new HashMap<>();
+        }
+        for (CloseableIterator<Row> it = table.execute().collect(); it.hasNext(); ) {
+            Row row = it.next();
+            for (int i = 0; i < inputCols.length; i++) {
+                maps[i].put(
+                        ((Number) row.getField(inputCols[i])).doubleValue(),
+                        (Vector) row.getField(outputCols[i]));
+            }
+        }
+        return maps;
+    }
+
+    @Test
+    public void testParam() {
+        OneHotEncoder estimator = new OneHotEncoder();
+
+        assertTrue(estimator.getDropLast());
+
+        estimator.setInputCols("test_input").setOutputCols("test_output").setDropLast(false);
+
+        assertArrayEquals(new String[] {"test_input"}, 
estimator.getInputCols());
+        assertArrayEquals(new String[] {"test_output"}, 
estimator.getOutputCols());
+        assertFalse(estimator.getDropLast());
+
+        OneHotEncoderModel model = new OneHotEncoderModel();
+
+        assertTrue(model.getDropLast());
+
+        model.setInputCols("test_input").setOutputCols("test_output").setDropLast(false);
+
+        assertArrayEquals(new String[] {"test_input"}, model.getInputCols());
+        assertArrayEquals(new String[] {"test_output"}, model.getOutputCols());
+        assertFalse(model.getDropLast());
+    }
+
+    @Test
+    public void testFitAndPredict() {
+        OneHotEncoderModel model = estimator.fit(trainTable);
+        Table outputTable = model.transform(predictTable)[0];
+        Map<Double, Vector>[] actualOutput =
+                executeAndCollect(outputTable, model.getInputCols(), model.getOutputCols());
+        assertArrayEquals(expectedOutput, actualOutput);
+    }
+
+    @Test
+    public void testDropLast() {
+        estimator.setDropLast(false);
+
+        expectedOutput =
+                new HashMap[] {
+                    new HashMap<Double, Vector>() {
+                        {
+                            put(0.0, Vectors.sparse(3, new int[] {0}, new double[] {1.0}));
+                            put(1.0, Vectors.sparse(3, new int[] {1}, new double[] {1.0}));
+                            put(2.0, Vectors.sparse(3, new int[] {2}, new double[] {1.0}));
+                        }
+                    }
+                };
+
+        OneHotEncoderModel model = estimator.fit(trainTable);
+        Table outputTable = model.transform(predictTable)[0];
+        Map<Double, Vector>[] actualOutput =
+                executeAndCollect(outputTable, model.getInputCols(), model.getOutputCols());
+        assertArrayEquals(expectedOutput, actualOutput);
+    }
+
+    @Test
+    public void testInputDataType() {
+        List<Row> trainData = Arrays.asList(Row.of(0), Row.of(1), Row.of(2), Row.of(0));
+
+        trainTable = tEnv.fromDataStream(env.fromCollection(trainData)).as("input");
+
+        List<Row> predictData = Arrays.asList(Row.of(0), Row.of(1), Row.of(2));
+        predictTable = tEnv.fromDataStream(env.fromCollection(predictData)).as("input");
+
+        expectedOutput =
+                new HashMap[] {
+                    new HashMap<Double, Vector>() {
+                        {
+                            put(0.0, Vectors.sparse(2, new int[] {0}, new double[] {1.0}));
+                            put(1.0, Vectors.sparse(2, new int[] {1}, new double[] {1.0}));
+                            put(2.0, Vectors.sparse(2, new int[0], new double[0]));
+                        }
+                    }
+                };
+
+        OneHotEncoderModel model = estimator.fit(trainTable);
+        Table outputTable = model.transform(predictTable)[0];
+        Map<Double, Vector>[] actualOutput =
+                executeAndCollect(outputTable, model.getInputCols(), model.getOutputCols());
+        assertArrayEquals(expectedOutput, actualOutput);
+    }
+
+    @Test
+    public void testNotSupportedHandleInvalidOptions() {
+        estimator.setHandleInvalid(HasHandleInvalid.SKIP_INVALID);
+        try {
+            estimator.fit(trainTable);
+            Assert.fail("Expected IllegalArgumentException");
+        } catch (Exception e) {
+            assertEquals(IllegalArgumentException.class, e.getClass());
+        }
+    }
+
+    @Test
+    public void testNonIndexedTrainData() {
+        List<Row> trainData = Arrays.asList(Row.of(0.5), Row.of(1.0), Row.of(2.0), Row.of(0.0));
+
+        trainTable = tEnv.fromDataStream(env.fromCollection(trainData)).as("input");
+        OneHotEncoderModel model = estimator.fit(trainTable);
+        Table outputTable = model.transform(predictTable)[0];
+        try {
+            outputTable.execute().collect().next();
+            Assert.fail("Expected IllegalArgumentException");
+        } catch (Exception e) {
+            Throwable exception = e;
+            while (exception.getCause() != null) {
+                exception = exception.getCause();
+            }
+            assertEquals(IllegalArgumentException.class, exception.getClass());
+            assertEquals("Value 0.5 cannot be parsed as indexed integer.", 
exception.getMessage());
+        }
+    }
+
+    @Test
+    public void testNonIndexedPredictData() {
+        List<Row> predictData = Arrays.asList(Row.of(0.5), Row.of(1.0), Row.of(2.0), Row.of(0.0));
+
+        predictTable = tEnv.fromDataStream(env.fromCollection(predictData)).as("input");
+        OneHotEncoderModel model = estimator.fit(trainTable);
+        Table outputTable = model.transform(predictTable)[0];
+        try {
+            outputTable.execute().collect().next();
+            Assert.fail("Expected IllegalArgumentException");
+        } catch (Exception e) {
+            Throwable exception = e;
+            while (exception.getCause() != null) {
+                exception = exception.getCause();
+            }
+            assertEquals(IllegalArgumentException.class, exception.getClass());
+            assertEquals("Value 0.5 cannot be parsed as indexed integer.", 
exception.getMessage());
+        }
+    }
+
+    @Test
+    public void testSaveLoad() throws Exception {
+        estimator =
+                StageTestUtils.saveAndReload(
+                        env, estimator, tempFolder.newFolder().getAbsolutePath());
+        OneHotEncoderModel model = estimator.fit(trainTable);
+        model = StageTestUtils.saveAndReload(env, model, tempFolder.newFolder().getAbsolutePath());
+        Table outputTable = model.transform(predictTable)[0];
+        Map<Double, Vector>[] actualOutput =
+                executeAndCollect(outputTable, model.getInputCols(), model.getOutputCols());
+        assertArrayEquals(expectedOutput, actualOutput);
+    }
+
+    @Test
+    public void testGetModelData() throws Exception {
+        OneHotEncoderModel model = estimator.fit(trainTable);
+        Tuple2<Integer, Integer> expected = new Tuple2<>(0, 2);
+        Tuple2<Integer, Integer> actual =
+                OneHotEncoderModelData.getModelDataStream(model.getModelData()[0])
+                        .executeAndCollect()
+                        .next();
+        assertEquals(expected, actual);
+    }
+
+    @Test
+    public void testSetModelData() {
+        OneHotEncoderModel modelA = estimator.fit(trainTable);
+
+        Table modelData = modelA.getModelData()[0];
+        OneHotEncoderModel modelB = new OneHotEncoderModel().setModelData(modelData);
+        ReadWriteUtils.updateExistingParams(modelB, modelA.getParamMap());
+
+        Table outputTable = modelB.transform(predictTable)[0];
+        Map<Double, Vector>[] actualOutput =
+                executeAndCollect(outputTable, modelB.getInputCols(), modelB.getOutputCols());
+        assertArrayEquals(expectedOutput, actualOutput);
+    }
+}
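Distilled from testFitAndPredict above, a minimal end-to-end sketch under the same assumptions as the tests; the OneHotEncoderSketch class name is illustrative, and the table setup mirrors the test code:

    import org.apache.flink.ml.feature.onehotencoder.OneHotEncoder;
    import org.apache.flink.ml.feature.onehotencoder.OneHotEncoderModel;
    import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
    import org.apache.flink.table.api.Table;
    import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
    import org.apache.flink.types.Row;

    import java.util.Arrays;

    /** Illustrative end-to-end run distilled from OneHotEncoderTest; not part of the patch. */
    public class OneHotEncoderSketch {
        public static void main(String[] args) throws Exception {
            StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

            Table trainTable =
                    tEnv.fromDataStream(
                                    env.fromCollection(
                                            Arrays.asList(Row.of(0.0), Row.of(1.0), Row.of(2.0))))
                            .as("input");

            OneHotEncoder estimator =
                    new OneHotEncoder().setInputCols("input").setOutputCols("output");
            OneHotEncoderModel model = estimator.fit(trainTable);

            // With dropLast (the default) and 3 distinct training values, outputs are size-2
            // sparse vectors; the last category (2.0) encodes to the all-zero vector.
            Table outputTable = model.transform(trainTable)[0];
            outputTable.execute().print();
        }
    }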
