This is an automated email from the ASF dual-hosted git repository. myui pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-hivemall.git
The following commit(s) were added to refs/heads/master by this push:
     new 72dca39  Added sanity checks for training data in RandomForest
72dca39 is described below

commit 72dca396c6851c9ea44df7eac86ba677ea21879e
Author: Makoto Yui <m...@apache.org>
AuthorDate: Wed Jul 10 16:17:20 2019 +0900

    Added sanity checks for training data in RandomForest
---
 .../classification/RandomForestClassifierUDTF.java |  10 ++
 .../RandomForestClassifierUDTFTest.java            | 101 ++++++++++++++++++++-
 2 files changed, 108 insertions(+), 3 deletions(-)

diff --git a/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java b/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java
index 7f2966b..99396b7 100644
--- a/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java
+++ b/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java
@@ -327,6 +327,16 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
         int[] y = labels.toArray();
         this.labels = null;
 
+        // sanity checks
+        if (x.numColumns() == 0) {
+            throw new HiveException(
+                "No non-null features in the training examples. Revise training data");
+        }
+        if (x.numRows() != y.length) {
+            throw new HiveException("Illegal condition was met. y.length=" + y.length
+                    + ", X.length=" + x.numRows());
+        }
+
         // run training
         train(x, y);
     }
diff --git a/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java b/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java
index 0793ae6..aa839fa 100644
--- a/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java
+++ b/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java
@@ -22,6 +22,8 @@ import hivemall.TestUtils;
 import hivemall.classifier.KernelExpansionPassiveAggressiveUDTF;
 import hivemall.utils.codec.Base91;
 import hivemall.utils.lang.mutable.MutableInt;
+import smile.data.AttributeDataset;
+import smile.data.parser.ArffParser;
 
 import java.io.BufferedInputStream;
 import java.io.BufferedReader;
@@ -32,6 +34,7 @@ import java.net.URL;
 import java.text.ParseException;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Random;
 import java.util.StringTokenizer;
 import java.util.zip.GZIPInputStream;
 
@@ -48,9 +51,6 @@ import org.apache.hadoop.io.Text;
 import org.junit.Assert;
 import org.junit.Test;
 
-import smile.data.AttributeDataset;
-import smile.data.parser.ArffParser;
-
 public class RandomForestClassifierUDTFTest {
 
     @Test
@@ -98,6 +98,101 @@ public class RandomForestClassifierUDTFTest {
     }
 
     @Test
+    public void testIrisDenseSomeNullFeaturesTest()
+            throws IOException, ParseException, HiveException {
+        URL url = new URL(
+            "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");
+        InputStream is = new BufferedInputStream(url.openStream());
+
+        ArffParser arffParser = new ArffParser();
+        arffParser.setResponseIndex(4);
+
+        AttributeDataset iris = arffParser.parse(is);
+        int size = iris.size();
+        double[][] x = iris.toArray(new double[size][]);
+        int[] y = iris.toArray(new int[size]);
+
+        RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
+        ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
+            PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 49");
+        udtf.initialize(new ObjectInspector[] {
+                ObjectInspectorFactory.getStandardListObjectInspector(
+                    PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
+                PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});
+
+        final Random rand = new Random(43);
+        final List<Double> xi = new ArrayList<Double>(x[0].length);
+        for (int i = 0; i < size; i++) {
+            for (int j = 0; j < x[i].length; j++) {
+                if (rand.nextDouble() >= 0.7) {
+                    xi.add(j, null);
+                } else {
+                    xi.add(j, x[i][j]);
+                }
+            }
+            udtf.process(new Object[] {xi, y[i]});
+            xi.clear();
+        }
+
+        final MutableInt count = new MutableInt(0);
+        Collector collector = new Collector() {
+            public void collect(Object input) throws HiveException {
+                count.addValue(1);
+            }
+        };
+
+        udtf.setCollector(collector);
+        udtf.close();
+
+        Assert.assertEquals(49, count.getValue());
+    }
+
+    @Test(expected = HiveException.class)
+    public void testIrisDenseAllNullFeaturesTest()
+            throws IOException, ParseException, HiveException {
+        URL url = new URL(
+            "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");
+        InputStream is = new BufferedInputStream(url.openStream());
+
+        ArffParser arffParser = new ArffParser();
+        arffParser.setResponseIndex(4);
+
+        AttributeDataset iris = arffParser.parse(is);
+        int size = iris.size();
+        double[][] x = iris.toArray(new double[size][]);
+        int[] y = iris.toArray(new int[size]);
+
+        RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
+        ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
+            PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 49");
+        udtf.initialize(new ObjectInspector[] {
+                ObjectInspectorFactory.getStandardListObjectInspector(
+                    PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
+                PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});
+
+        final List<Double> xi = new ArrayList<Double>(x[0].length);
+        for (int i = 0; i < size; i++) {
+            for (int j = 0; j < x[i].length; j++) {
+                xi.add(j, null);
+            }
+            udtf.process(new Object[] {xi, y[i]});
+            xi.clear();
+        }
+
+        final MutableInt count = new MutableInt(0);
+        Collector collector = new Collector() {
+            public void collect(Object input) throws HiveException {
+                count.addValue(1);
+            }
+        };
+
+        udtf.setCollector(collector);
+        udtf.close();
+
+        Assert.fail("should not be called");
+    }
+
+    @Test
     public void testIrisSparse() throws IOException, ParseException, HiveException {
         URL url = new URL(
             "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");