This is an automated email from the ASF dual-hosted git repository.

zaleslaw pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/ignite.git
The following commit(s) were added to refs/heads/master by this push:
     new 1bf68b0  IGNITE-11655: [ML] OneHotEncoder returns more columns than expected (#6376)
1bf68b0 is described below

commit 1bf68b06a7658aa6dc1a4f3327311cc08194047b
Author: Alexey Zinoviev <zaleslaw....@gmail.com>
AuthorDate: Fri Mar 29 18:26:49 2019 +0300

    IGNITE-11655: [ML] OneHotEncoder returns more columns than expected (#6376)
---
 .../onehotencoder/OneHotEncoderPreprocessor.java   | 21 ++---
 .../preprocessing/encoding/EncoderTrainerTest.java |  6 +-
 .../encoding/OneHotEncoderPreprocessorTest.java    | 94 ++++++++++++++++++++--
 3 files changed, 101 insertions(+), 20 deletions(-)

diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/onehotencoder/OneHotEncoderPreprocessor.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/onehotencoder/OneHotEncoderPreprocessor.java
index 7aadadf..96479be 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/onehotencoder/OneHotEncoderPreprocessor.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/onehotencoder/OneHotEncoderPreprocessor.java
@@ -36,7 +36,7 @@ import org.apache.ignite.ml.preprocessing.encoding.EncoderTrainer;
  *
  * This preprocessor can transform multiple columns which indices are handled during training process.
  *
- * Each one-hot encoded binary vector adds its cells to the end of the current feature vector.
+ * Each one-hot encoded binary vector adds its cells to the end of the current feature vector according the order of handled categorial features.
  *
  * @param <K> Type of a key in {@code upstream} data.
  * @param <V> Type of a value in {@code upstream} data.
@@ -70,38 +70,39 @@ public class OneHotEncoderPreprocessor<K, V> extends EncoderPreprocessor<K, V> { */ @Override public Vector apply(K k, V v) { Object[] tmp = basePreprocessor.apply(k, v); + int amountOfCategorialFeatures = handledIndices.size(); - double[] res = new double[tmp.length + getAdditionalSize(encodingValues)]; + double[] res = new double[tmp.length - amountOfCategorialFeatures + getAdditionalSize(encodingValues)]; int categorialFeatureCntr = 0; + int resIdx = 0; for (int i = 0; i < tmp.length; i++) { Object tmpObj = tmp[i]; + if (handledIndices.contains(i)) { categorialFeatureCntr++; if (tmpObj.equals(Double.NaN) && encodingValues[i].containsKey(KEY_FOR_NULL_VALUES)) { final Integer indexedVal = encodingValues[i].get(KEY_FOR_NULL_VALUES); - res[i] = indexedVal; - - res[tmp.length + getIdxOffset(categorialFeatureCntr, indexedVal, encodingValues)] = 1.0; + res[tmp.length - amountOfCategorialFeatures + getIdxOffset(categorialFeatureCntr, indexedVal, encodingValues)] = 1.0; } else { final String key = String.valueOf(tmpObj); if (encodingValues[i].containsKey(key)) { final Integer indexedVal = encodingValues[i].get(key); - res[i] = indexedVal; - - res[tmp.length + getIdxOffset(categorialFeatureCntr, indexedVal, encodingValues)] = 1.0; + res[tmp.length - amountOfCategorialFeatures + getIdxOffset(categorialFeatureCntr, indexedVal, encodingValues)] = 1.0; } else throw new UnknownCategorialFeatureValue(tmpObj.toString()); } - } else - res[i] = (double) tmpObj; + } else { + res[resIdx] = (double) tmpObj; + resIdx++; + } } return VectorUtils.of(res); } diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainerTest.java index 1bf69e5..bee715c 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainerTest.java +++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainerTest.java @@ -26,8 +26,8 @@ import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; import org.apache.ignite.ml.math.exceptions.preprocessing.UnknownCategorialFeatureValue; import org.junit.Test; -import static org.junit.Assert.fail; import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.fail; /** * Tests for {@link EncoderTrainer}. @@ -83,8 +83,8 @@ public class EncoderTrainerTest extends TrainerTest { datasetBuilder, (k, v) -> v ); - assertArrayEquals(new double[]{0.0, 2.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0}, preprocessor.apply(7, new Double[]{3.0, 0.0}).asArray(), 1e-8); - assertArrayEquals(new double[]{1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0}, preprocessor.apply(8, new Double[]{2.0, 12.0}).asArray(), 1e-8); + assertArrayEquals(new double[]{1.0, 0.0, 0.0, 0.0, 0.0, 1.0}, preprocessor.apply(7, new Double[]{3.0, 0.0}).asArray(), 1e-8); + assertArrayEquals(new double[]{0.0, 1.0, 0.0, 1.0, 0.0, 0.0}, preprocessor.apply(8, new Double[]{2.0, 12.0}).asArray(), 1e-8); } /** Tests {@code fit()} method. */ diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/OneHotEncoderPreprocessorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/OneHotEncoderPreprocessorTest.java index 5af7335..b60b0ff 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/OneHotEncoderPreprocessorTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/OneHotEncoderPreprocessorTest.java @@ -24,8 +24,8 @@ import org.apache.ignite.ml.preprocessing.encoding.onehotencoder.OneHotEncoderPr import org.apache.ignite.ml.preprocessing.encoding.stringencoder.StringEncoderPreprocessor; import org.junit.Test; -import static org.junit.Assert.fail; import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.fail; /** * Tests for {@link StringEncoderPreprocessor}. 
@@ -66,9 +66,9 @@ public class OneHotEncoderPreprocessorTest { }); double[][] postProcessedData = new double[][]{ - {1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0}, - {0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0}, - {0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0}, + {0.0, 1.0, 1.0, 1.0, 0.0}, + {1.0, 0.0, 1.0, 1.0, 0.0}, + {1.0, 0.0, 1.0, 0.0, 1.0}, }; for (int i = 0; i < data.length; i++) @@ -76,6 +76,86 @@ public class OneHotEncoderPreprocessorTest { } + /** */ + @Test + public void testOneCategorialFeature() { + String[][] data = new String[][]{ + {"42"}, + {"43"}, + {"42"}, + }; + + OneHotEncoderPreprocessor<Integer, String[]> preprocessor = new OneHotEncoderPreprocessor<Integer, String[]>( + new HashMap[]{new HashMap() { + { + put("42", 0); + put("43", 1); + } + }}, + (k, v) -> v, + new HashSet() { + { + add(0); + } + }); + + double[][] postProcessedData = new double[][]{ + {1.0, 0.0}, + {0.0, 1.0}, + {1.0, 0.0}, + }; + + for (int i = 0; i < data.length; i++) + assertArrayEquals(postProcessedData[i], preprocessor.apply(i, data[i]).asArray(), 1e-8); + } + + /** */ + @Test + public void testTwoCategorialFeatureAndTwoDoubleFeatures() { + Object[][] data = new Object[][]{ + {"42", 1.0, "M", 2.0}, + {"43", 2.0, "F", 3.0}, + {"42", 3.0, Double.NaN, 4.0}, + {"42", 4.0, "F", 5.0}, + }; + + HashMap[] encodingValues = new HashMap[4]; + encodingValues[0] = new HashMap() { + { + put("42", 0); + put("43", 1); + } + }; + + encodingValues[2] = new HashMap() { + { + put("F", 0); + put("M", 1); + put("", 2); + } + }; + + OneHotEncoderPreprocessor<Integer, Object[]> preprocessor = new OneHotEncoderPreprocessor<Integer, Object[]>( + encodingValues, + (k, v) -> v, + new HashSet() { + { + add(0); + add(2); + } + }); + + double[][] postProcessedData = new double[][]{ + {1.0, 2.0, 1.0, 0.0, 0.0, 1.0, 0.0}, + {2.0, 3.0, 0.0, 1.0, 1.0, 0.0, 0.0}, + {3.0, 4.0, 1.0, 0.0, 0.0, 0.0, 1.0}, + {4.0, 5.0, 1.0, 0.0, 1.0, 0.0, 0.0}, + }; + + for (int i = 0; i < data.length; i++) + 
assertArrayEquals(postProcessedData[i], preprocessor.apply(i, data[i]).asArray(), 1e-8); + } + /** * The {@code apply()} method is failed with UnknownCategorialFeatureValue exception. * @@ -116,9 +196,9 @@ public class OneHotEncoderPreprocessorTest { }); double[][] postProcessedData = new double[][]{ - {1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0}, - {0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0}, - {0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0}, + {0.0, 1.0, 1.0, 1.0, 0.0}, + {1.0, 0.0, 1.0, 1.0, 0.0}, + {1.0, 0.0, 1.0, 0.0, 1.0}, }; try {