Repository: ignite
Updated Branches:
  refs/heads/master 9249efda5 -> 7a5aa7c6b


IGNITE-8664: Encoding categorical features with One-of-K Encoder

this closes #4106


Project: http://git-wip-us.apache.org/repos/asf/ignite/repo
Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/7a5aa7c6
Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/7a5aa7c6
Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/7a5aa7c6

Branch: refs/heads/master
Commit: 7a5aa7c6b91bbf8c6bbdabc232f3d09ac1b015f9
Parents: 9249efd
Author: zaleslaw <zaleslaw....@gmail.com>
Authored: Fri Jun 1 17:20:40 2018 +0300
Committer: Yury Babak <yba...@gridgain.com>
Committed: Fri Jun 1 17:20:40 2018 +0300

----------------------------------------------------------------------
 .../preprocessing/UnknownStringValue.java       |  35 +++++
 .../ml/preprocessing/encoding/package-info.java |  22 +++
 .../StringEncoderPartitionData.java             |  62 ++++++++
 .../StringEncoderPreprocessor.java              |  70 +++++++++
 .../stringencoder/StringEncoderTrainer.java     | 152 +++++++++++++++++++
 .../encoding/stringencoder/package-info.java    |  22 +++
 .../preprocessing/PreprocessingTestSuite.java   |   6 +-
 .../encoding/StringEncoderPreprocessorTest.java |  67 ++++++++
 .../encoding/StringEncoderTrainerTest.java      |  78 ++++++++++
 9 files changed, 513 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/ignite/blob/7a5aa7c6/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/preprocessing/UnknownStringValue.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/preprocessing/UnknownStringValue.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/preprocessing/UnknownStringValue.java
new file mode 100644
index 0000000..f2312a1
--- /dev/null
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/preprocessing/UnknownStringValue.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.math.exceptions.preprocessing;
+
+import org.apache.ignite.IgniteException;
+
+/**
+ * Signals that a String value is not known to the StringEncoder.
+ */
+public class UnknownStringValue extends IgniteException {
+    /** Serial version UID. */
+    private static final long serialVersionUID = 0L;
+
+    /**
+     * Creates the exception for the given unrecognized value.
+     *
+     * @param unknownString String value that caused this exception.
+     */
+    public UnknownStringValue(String unknownString) {
+        super(describe(unknownString));
+    }
+
+    /**
+     * Builds the exception message for the given value.
+     *
+     * @param unknownString String value that caused this exception.
+     * @return Message text.
+     */
+    private static String describe(String unknownString) {
+        return "This String value is unknown for StringEncoder: " + unknownString;
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/7a5aa7c6/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/package-info.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/package-info.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/package-info.java
new file mode 100644
index 0000000..436ad8f
--- /dev/null
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains encoding preprocessors.
+ */
+package org.apache.ignite.ml.preprocessing.encoding;

http://git-wip-us.apache.org/repos/asf/ignite/blob/7a5aa7c6/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderPartitionData.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderPartitionData.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderPartitionData.java
new file mode 100644
index 0000000..acd2aa2
--- /dev/null
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderPartitionData.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.preprocessing.encoding.stringencoder;
+
+import java.util.Map;
+
+/**
+ * Per-partition state of the String Encoder preprocessor: the category frequencies of one partition.
+ *
+ * @see StringEncoderTrainer
+ * @see StringEncoderPreprocessor
+ */
+public class StringEncoderPartitionData implements AutoCloseable {
+    /** One map per categorical feature: category (as a string) to the number of its occurrences. */
+    private Map<String, Integer>[] freqsByFeature;
+
+    /**
+     * Creates partition data with no frequencies assigned yet.
+     */
+    public StringEncoderPartitionData() {
+        // No-op.
+    }
+
+    /**
+     * Returns the per-feature frequency maps held by this partition data.
+     *
+     * @return The frequencies, or {@code null} if none have been set.
+     */
+    public Map<String, Integer>[] categoryFrequencies() {
+        return freqsByFeature;
+    }
+
+    /**
+     * Stores the per-feature frequency maps (the array is kept by reference, not copied).
+     *
+     * @param categoryFrequencies The given value.
+     * @return The partition data.
+     */
+    public StringEncoderPartitionData withCategoryFrequencies(Map<String, Integer>[] categoryFrequencies) {
+        freqsByFeature = categoryFrequencies;
+
+        return this;
+    }
+
+    /** Has nothing to release explicitly. */
+    @Override public void close() {
+        // Do nothing, GC will clean up.
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/7a5aa7c6/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderPreprocessor.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderPreprocessor.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderPreprocessor.java
new file mode 100644
index 0000000..4b21e67
--- /dev/null
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderPreprocessor.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.preprocessing.encoding.stringencoder;
+
+import java.util.Map;
+import org.apache.ignite.ml.math.exceptions.preprocessing.UnknownStringValue;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+
+/**
+ * Preprocessing function that replaces every string category of a row with its numeric encoding value.
+ *
+ * @param <K> Type of a key in {@code upstream} data.
+ * @param <V> Type of a value in {@code upstream} data.
+ */
+public class StringEncoderPreprocessor<K, V> implements IgniteBiFunction<K, V, double[]> {
+    /** */
+    private static final long serialVersionUID = 6237812226382623469L;
+
+    /** Encoding values: for each feature index, a map from category to its numeric code. */
+    private final Map<String, Integer>[] encodingValues;
+
+    /** Base preprocessor. */
+    private final IgniteBiFunction<K, V, String[]> basePreprocessor;
+
+    /**
+     * Constructs a new instance of String Encoder preprocessor.
+     *
+     * @param encodingValues Encoding values, one map per feature.
+     * @param basePreprocessor Base preprocessor.
+     */
+    public StringEncoderPreprocessor(Map<String, Integer>[] encodingValues,
+        IgniteBiFunction<K, V, String[]> basePreprocessor) {
+        this.basePreprocessor = basePreprocessor;
+        this.encodingValues = encodingValues;
+    }
+
+    /**
+     * Applies this preprocessor: runs the base preprocessor and encodes each of the returned strings.
+     *
+     * @param k Key.
+     * @param v Value.
+     * @return Preprocessed row.
+     * @throws UnknownStringValue If a string has no entry in the encoding map of its feature.
+     */
+    @Override public double[] apply(K k, V v) {
+        String[] categories = basePreprocessor.apply(k, v);
+        double[] encoded = new double[categories.length];
+
+        for (int idx = 0; idx < encoded.length; idx++) {
+            Map<String, Integer> codes = encodingValues[idx];
+            String category = categories[idx];
+
+            if (!codes.containsKey(category))
+                throw new UnknownStringValue(category);
+
+            encoded[idx] = codes.get(category);
+        }
+
+        return encoded;
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/7a5aa7c6/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderTrainer.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderTrainer.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderTrainer.java
new file mode 100644
index 0000000..5a4d090
--- /dev/null
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/StringEncoderTrainer.java
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.preprocessing.encoding.stringencoder;
+
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.stream.Collectors;
+import org.apache.ignite.ml.dataset.Dataset;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.UpstreamEntry;
+import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.preprocessing.PreprocessingTrainer;
+
+/**
+ * Trainer of the String Encoder preprocessor.
+ * The String Encoder encodes string values (categories) to double values in range [0.0, amountOfCategories)
+ * where the most popular value will be presented as 0.0 and the least popular value presented with
+ * amountOfCategories-1 value.
+ *
+ * @param <K> Type of a key in {@code upstream} data.
+ * @param <V> Type of a value in {@code upstream} data.
+ */
+public class StringEncoderTrainer<K, V> implements PreprocessingTrainer<K, V, String[], double[]> {
+    /**
+     * {@inheritDoc}
+     *
+     * Any failure, including an attempt to fit on a dataset without rows, is rethrown wrapped
+     * into a {@link RuntimeException}.
+     */
+    @Override public StringEncoderPreprocessor<K, V> fit(DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, String[]> basePreprocessor) {
+        try (Dataset<EmptyContext, StringEncoderPartitionData> dataset = datasetBuilder.build(
+            (upstream, upstreamSize) -> new EmptyContext(),
+            (upstream, upstreamSize, ctx) -> {
+                // Stays null for a partition without rows; the reducer treats null as "no data".
+                Map<String, Integer>[] categoryFrequencies = null;
+
+                while (upstream.hasNext()) {
+                    UpstreamEntry<K, V> entity = upstream.next();
+                    String[] row = basePreprocessor.apply(entity.getKey(), entity.getValue());
+
+                    categoryFrequencies = calculateFrequencies(row, categoryFrequencies);
+                }
+
+                return new StringEncoderPartitionData()
+                    .withCategoryFrequencies(categoryFrequencies);
+            }
+        )) {
+            Map<String, Integer>[] encodingValues = calculateEncodingValuesByFrequencies(dataset);
+
+            return new StringEncoderPreprocessor<>(encodingValues, basePreprocessor);
+        }
+        catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    /**
+     * Calculates the encoding values by the frequencies kept in the given dataset.
+     * Static (like the other helpers) because it uses no trainer state.
+     *
+     * @param dataset The dataset of frequencies for each feature aggregated in each partition.
+     * @return Encoding values for each feature.
+     * @throws IllegalStateException If no partition contributed any frequencies.
+     */
+    private static Map<String, Integer>[] calculateEncodingValuesByFrequencies(
+        Dataset<EmptyContext, StringEncoderPartitionData> dataset) {
+        Map<String, Integer>[] frequencies = dataset.compute(
+            StringEncoderPartitionData::categoryFrequencies,
+            (a, b) -> {
+                if (a == null)
+                    return b;
+
+                if (b == null)
+                    return a;
+
+                assert a.length == b.length;
+
+                // Fold the counters of "a" into "b" feature by feature and keep "b" as the accumulator.
+                for (int i = 0; i < a.length; i++) {
+                    Map<String, Integer> target = b[i];
+
+                    a[i].forEach((k, v) -> target.merge(k, v, Integer::sum));
+                }
+
+                return b;
+            }
+        );
+
+        // Without this check an empty dataset would surface as a bare NullPointerException below.
+        if (frequencies == null)
+            throw new IllegalStateException("String Encoder cannot be trained on an empty dataset");
+
+        // Generic array creation is not possible, hence the raw array and the suppressed warning.
+        @SuppressWarnings("unchecked")
+        Map<String, Integer>[] res = new HashMap[frequencies.length];
+
+        for (int i = 0; i < frequencies.length; i++)
+            res[i] = transformFrequenciesToEncodingValues(frequencies[i]);
+
+        return res;
+    }
+
+    /**
+     * Transforms frequencies to the encoding values: the most frequent category gets 0 and the least
+     * frequent one gets {@code frequencies.size() - 1}. Categories with equal frequencies keep the
+     * relative order in which the given map iterates over them.
+     *
+     * @param frequencies Frequencies of categories for the specific feature.
+     * @return Encoding values.
+     */
+    private static Map<String, Integer> transformFrequenciesToEncodingValues(Map<String, Integer> frequencies) {
+        // Ascending by frequency; the linked map preserves that order for the numbering pass below.
+        Map<String, Integer> resMap = frequencies.entrySet()
+            .stream()
+            .sorted(Map.Entry.comparingByValue())
+            .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue,
+                (oldValue, newValue) -> oldValue, LinkedHashMap::new));
+
+        int amountOfLabels = frequencies.size();
+
+        // Counting down while walking from the rarest category leaves 0 for the most popular one.
+        for (Map.Entry<String, Integer> m : resMap.entrySet())
+            m.setValue(--amountOfLabels);
+
+        return resMap;
+    }
+
+    /**
+     * Updates frequencies by values and features.
+     *
+     * @param row Feature vector.
+     * @param categoryFrequencies Holds the frequencies of categories by values and features,
+     *      or {@code null} if the row is the first one.
+     * @return Updated frequencies by values and features.
+     */
+    @SuppressWarnings("unchecked") // Generic array creation is not possible, a raw array is created instead.
+    private static Map<String, Integer>[] calculateFrequencies(String[] row,
+        Map<String, Integer>[] categoryFrequencies) {
+        if (categoryFrequencies == null) {
+            categoryFrequencies = new HashMap[row.length];
+
+            for (int i = 0; i < categoryFrequencies.length; i++)
+                categoryFrequencies[i] = new HashMap<>();
+        }
+        else
+            assert categoryFrequencies.length == row.length : "Base preprocessor must return exactly "
+                + categoryFrequencies.length + " features";
+
+        for (int i = 0; i < categoryFrequencies.length; i++)
+            categoryFrequencies[i].merge(row[i], 1, Integer::sum);
+
+        return categoryFrequencies;
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/7a5aa7c6/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/package-info.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/package-info.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/package-info.java
new file mode 100644
index 0000000..7cdb40f
--- /dev/null
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/stringencoder/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains string encoding preprocessor.
+ */
+package org.apache.ignite.ml.preprocessing.encoding.stringencoder;

http://git-wip-us.apache.org/repos/asf/ignite/blob/7a5aa7c6/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/PreprocessingTestSuite.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/PreprocessingTestSuite.java
 
b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/PreprocessingTestSuite.java
index f0c566c..cb29ecb 100644
--- 
a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/PreprocessingTestSuite.java
+++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/PreprocessingTestSuite.java
@@ -19,6 +19,8 @@ package org.apache.ignite.ml.preprocessing;
 
 import 
org.apache.ignite.ml.preprocessing.binarization.BinarizationPreprocessorTest;
 import org.apache.ignite.ml.preprocessing.binarization.BinarizationTrainerTest;
+import 
org.apache.ignite.ml.preprocessing.encoding.StringEncoderPreprocessorTest;
+import org.apache.ignite.ml.preprocessing.encoding.StringEncoderTrainerTest;
 import org.apache.ignite.ml.preprocessing.imputing.ImputerPreprocessorTest;
 import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainerTest;
 import 
org.apache.ignite.ml.preprocessing.normalization.NormalizationPreprocessorTest;
@@ -36,7 +38,9 @@ import org.junit.runners.Suite;
     BinarizationPreprocessorTest.class,
     BinarizationTrainerTest.class,
     ImputerPreprocessorTest.class,
-    ImputerTrainerTest.class
+    ImputerTrainerTest.class,
+    StringEncoderTrainerTest.class,
+    StringEncoderPreprocessorTest.class
 })
 public class PreprocessingTestSuite {
     // No-op.

http://git-wip-us.apache.org/repos/asf/ignite/blob/7a5aa7c6/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderPreprocessorTest.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderPreprocessorTest.java
 
b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderPreprocessorTest.java
new file mode 100644
index 0000000..d74b923
--- /dev/null
+++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderPreprocessorTest.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.preprocessing.encoding;
+
+import java.util.HashMap;
+import 
org.apache.ignite.ml.preprocessing.encoding.stringencoder.StringEncoderPreprocessor;
+import org.junit.Test;
+
+import static org.junit.Assert.assertArrayEquals;
+
+/**
+ * Tests for {@link StringEncoderPreprocessor}.
+ */
+public class StringEncoderPreprocessorTest {
+    /** Tests {@code apply()} method: every string must be replaced with the code from its feature's map. */
+    @Test
+    public void testApply() {
+        String[][] data = new String[][]{
+            {"1", "Moscow", "A"},
+            {"2", "Moscow", "B"},
+            {"2", "Moscow", "B"},
+        };
+
+        // Plain typed maps instead of double-brace initialization, which creates raw-typed anonymous
+        // HashMap subclasses holding a reference to the test instance.
+        HashMap<String, Integer> firstFeatureCodes = new HashMap<>();
+        firstFeatureCodes.put("1", 1);
+        firstFeatureCodes.put("2", 0);
+
+        HashMap<String, Integer> secondFeatureCodes = new HashMap<>();
+        secondFeatureCodes.put("Moscow", 0);
+
+        HashMap<String, Integer> thirdFeatureCodes = new HashMap<>();
+        thirdFeatureCodes.put("A", 1);
+        thirdFeatureCodes.put("B", 0);
+
+        StringEncoderPreprocessor<Integer, String[]> preprocessor = new StringEncoderPreprocessor<Integer, String[]>(
+            new HashMap[] {firstFeatureCodes, secondFeatureCodes, thirdFeatureCodes},
+            (k, v) -> v
+        );
+
+        double[][] postProcessedData = new double[][]{
+            {1.0, 0.0, 1.0},
+            {0.0, 0.0, 0.0},
+            {0.0, 0.0, 0.0},
+        };
+
+        for (int i = 0; i < data.length; i++)
+            assertArrayEquals(postProcessedData[i], preprocessor.apply(i, data[i]), 1e-8);
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/7a5aa7c6/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderTrainerTest.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderTrainerTest.java
 
b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderTrainerTest.java
new file mode 100644
index 0000000..aa17beb
--- /dev/null
+++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/StringEncoderTrainerTest.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.preprocessing.encoding;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
+import 
org.apache.ignite.ml.preprocessing.encoding.stringencoder.StringEncoderPreprocessor;
+import 
org.apache.ignite.ml.preprocessing.encoding.stringencoder.StringEncoderTrainer;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import static org.junit.Assert.assertArrayEquals;
+
+/**
+ * Tests for {@link StringEncoderTrainer}, run once per partition count returned by {@link #data()}.
+ */
+@RunWith(Parameterized.class)
+public class StringEncoderTrainerTest {
+    /** Parameters: the numbers of partitions the training data is split into. */
+    @Parameterized.Parameters(name = "Data divided on {0} partitions")
+    public static Iterable<Integer[]> data() {
+        return Arrays.asList(
+            new Integer[] {1},
+            new Integer[] {2},
+            new Integer[] {3},
+            new Integer[] {5},
+            new Integer[] {7},
+            new Integer[] {100},
+            new Integer[] {1000}
+        );
+    }
+
+    /** Number of partitions. */
+    @Parameterized.Parameter
+    public int parts;
+
+    /**
+     * Tests {@code fit()} method: "Monday" is the most frequent value of the first feature and
+     * "September" is the least frequent of the three values of the second one.
+     */
+    @Test
+    public void testFit() {
+        Map<Integer, String[]> rows = new HashMap<>();
+
+        rows.put(1, new String[] {"Monday", "September"});
+        rows.put(2, new String[] {"Monday", "August"});
+        rows.put(3, new String[] {"Monday", "August"});
+        rows.put(4, new String[] {"Friday", "June"});
+        rows.put(5, new String[] {"Friday", "June"});
+        rows.put(6, new String[] {"Sunday", "August"});
+
+        DatasetBuilder<Integer, String[]> builder = new LocalDatasetBuilder<>(rows, parts);
+
+        StringEncoderTrainer<Integer, String[]> trainer = new StringEncoderTrainer<>();
+
+        StringEncoderPreprocessor<Integer, String[]> encoder = trainer.fit(builder, (k, v) -> v);
+
+        double[] expected = new double[] {0.0, 2.0};
+        double[] actual = encoder.apply(7, new String[] {"Monday", "September"});
+
+        assertArrayEquals(expected, actual, 1e-8);
+    }
+}

Reply via email to