http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureTest.java new file mode 100644 index 0000000..3d11d9d --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureTest.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.tree.impurity.mse; + +import java.util.Random; +import org.junit.Test; + +import static junit.framework.TestCase.assertEquals; + +/** + * Tests for {@link MSEImpurityMeasure}. + */ +public class MSEImpurityMeasureTest { + /** */ + @Test + public void testImpurityOnEmptyData() { + MSEImpurityMeasure impurity = new MSEImpurityMeasure(0, 0, 0, 0, 0, 0); + + assertEquals(0.0, impurity.impurity(), 1e-10); + } + + /** */ + @Test + public void testImpurityLeftPart() { + // Test on left part [1, 2, 2, 1, 1, 1]. + MSEImpurityMeasure impurity = new MSEImpurityMeasure(8, 12, 6, 0, 0, 0); + + assertEquals(1.333, impurity.impurity(), 1e-3); + } + + /** */ + @Test + public void testImpurityRightPart() { + // Test on right part [1, 2, 2, 1, 1, 1]. + MSEImpurityMeasure impurity = new MSEImpurityMeasure(0, 0, 0, 8, 12, 6); + + assertEquals(1.333, impurity.impurity(), 1e-3); + } + + /** */ + @Test + public void testImpurityLeftAndRightPart() { + // Test on left part [1, 2, 2] and right part [1, 1, 1]. + MSEImpurityMeasure impurity = new MSEImpurityMeasure(5, 9, 3, 3, 3, 3); + + assertEquals(0.666, impurity.impurity(), 1e-3); + } + + /** */ + @Test + public void testAdd() { + Random rnd = new Random(0); + + MSEImpurityMeasure a = new MSEImpurityMeasure( + rnd.nextDouble(), rnd.nextDouble(), rnd.nextInt(), rnd.nextDouble(), rnd.nextDouble(), rnd.nextInt() + ); + + MSEImpurityMeasure b = new MSEImpurityMeasure( + rnd.nextDouble(), rnd.nextDouble(), rnd.nextInt(), rnd.nextDouble(), rnd.nextDouble(), rnd.nextInt() + ); + + MSEImpurityMeasure c = a.add(b); + + assertEquals(a.getLeftY() + b.getLeftY(), c.getLeftY(), 1e-10); + assertEquals(a.getLeftY2() + b.getLeftY2(), c.getLeftY2(), 1e-10); + assertEquals(a.getLeftCnt() + b.getLeftCnt(), c.getLeftCnt()); + assertEquals(a.getRightY() + b.getRightY(), c.getRightY(), 1e-10); + assertEquals(a.getRightY2() + b.getRightY2(), c.getRightY2(), 1e-10); + assertEquals(a.getRightCnt() + b.getRightCnt(), c.getRightCnt()); + } + + /** */ + @Test + public void testSubtract() { + Random rnd = new Random(0); + + MSEImpurityMeasure a = new MSEImpurityMeasure( + rnd.nextDouble(), rnd.nextDouble(), rnd.nextInt(), rnd.nextDouble(), rnd.nextDouble(), rnd.nextInt() + ); + + MSEImpurityMeasure b = new MSEImpurityMeasure( + rnd.nextDouble(), rnd.nextDouble(), rnd.nextInt(), rnd.nextDouble(), rnd.nextDouble(), rnd.nextInt() + ); + + MSEImpurityMeasure c = a.subtract(b); + + assertEquals(a.getLeftY() - b.getLeftY(), c.getLeftY(), 1e-10); + assertEquals(a.getLeftY2() - b.getLeftY2(), c.getLeftY2(), 1e-10); + assertEquals(a.getLeftCnt() - b.getLeftCnt(), c.getLeftCnt()); + assertEquals(a.getRightY() - b.getRightY(), c.getRightY(), 1e-10); + assertEquals(a.getRightY2() - b.getRightY2(), c.getRightY2(), 1e-10); + assertEquals(a.getRightCnt() - b.getRightCnt(), c.getRightCnt()); + } +}
http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/util/SimpleStepFunctionCompressorTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/util/SimpleStepFunctionCompressorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/util/SimpleStepFunctionCompressorTest.java new file mode 100644 index 0000000..001404f --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/util/SimpleStepFunctionCompressorTest.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.tree.impurity.util; + +import org.junit.Test; + +import static org.junit.Assert.assertArrayEquals; + +/** + * Tests for {@link SimpleStepFunctionCompressor}. + */ +public class SimpleStepFunctionCompressorTest { + /** */ + @Test + public void testCompressSmallFunction() { + StepFunction<TestImpurityMeasure> function = new StepFunction<>( + new double[]{1, 2, 3, 4}, + TestImpurityMeasure.asTestImpurityMeasures(1, 2, 3, 4) + ); + + SimpleStepFunctionCompressor<TestImpurityMeasure> compressor = new SimpleStepFunctionCompressor<>(5, 0, 0); + + StepFunction<TestImpurityMeasure> resFunction = compressor.compress(function); + + assertArrayEquals(new double[]{1, 2, 3, 4}, resFunction.getX(), 1e-10); + assertArrayEquals(TestImpurityMeasure.asTestImpurityMeasures(1, 2, 3, 4), resFunction.getY()); + } + + /** */ + @Test + public void testCompressIncreasingFunction() { + StepFunction<TestImpurityMeasure> function = new StepFunction<>( + new double[]{1, 2, 3, 4, 5}, + TestImpurityMeasure.asTestImpurityMeasures(1, 2, 3, 4, 5) + ); + + SimpleStepFunctionCompressor<TestImpurityMeasure> compressor = new SimpleStepFunctionCompressor<>(1, 0.4, 0); + + StepFunction<TestImpurityMeasure> resFunction = compressor.compress(function); + + assertArrayEquals(new double[]{1, 3, 5}, resFunction.getX(), 1e-10); + assertArrayEquals(TestImpurityMeasure.asTestImpurityMeasures(1, 3, 5), resFunction.getY()); + } + + /** */ + @Test + public void testCompressDecreasingFunction() { + StepFunction<TestImpurityMeasure> function = new StepFunction<>( + new double[]{1, 2, 3, 4, 5}, + TestImpurityMeasure.asTestImpurityMeasures(5, 4, 3, 2, 1) + ); + + SimpleStepFunctionCompressor<TestImpurityMeasure> compressor = new SimpleStepFunctionCompressor<>(1, 0, 0.4); + + StepFunction<TestImpurityMeasure> resFunction = compressor.compress(function); + + assertArrayEquals(new double[]{1, 3, 5}, resFunction.getX(), 1e-10); + assertArrayEquals(TestImpurityMeasure.asTestImpurityMeasures(5, 3, 1), resFunction.getY()); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/util/StepFunctionTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/util/StepFunctionTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/util/StepFunctionTest.java new file mode 100644 index 0000000..2a0279c --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/util/StepFunctionTest.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.tree.impurity.util; + +import org.junit.Test; + +import static org.junit.Assert.assertArrayEquals; + +/** + * Tests for {@link StepFunction}. + */ +public class StepFunctionTest { + /** */ + @Test + public void testAddIncreasingFunctions() { + StepFunction<TestImpurityMeasure> a = new StepFunction<>( + new double[]{1, 3, 5}, + TestImpurityMeasure.asTestImpurityMeasures(1, 2, 3) + ); + + StepFunction<TestImpurityMeasure> b = new StepFunction<>( + new double[]{0, 2, 4}, + TestImpurityMeasure.asTestImpurityMeasures(1, 2, 3) + ); + + StepFunction<TestImpurityMeasure> c = a.add(b); + + assertArrayEquals(new double[]{0, 1, 2, 3, 4, 5}, c.getX(), 1e-10); + assertArrayEquals( + TestImpurityMeasure.asTestImpurityMeasures(1, 2, 3, 4, 5, 6), + c.getY() + ); + } + + /** */ + @Test + public void testAddDecreasingFunctions() { + StepFunction<TestImpurityMeasure> a = new StepFunction<>( + new double[]{1, 3, 5}, + TestImpurityMeasure.asTestImpurityMeasures(3, 2, 1) + ); + + StepFunction<TestImpurityMeasure> b = new StepFunction<>( + new double[]{0, 2, 4}, + TestImpurityMeasure.asTestImpurityMeasures(3, 2, 1) + ); + + StepFunction<TestImpurityMeasure> c = a.add(b); + + assertArrayEquals(new double[]{0, 1, 2, 3, 4, 5}, c.getX(), 1e-10); + assertArrayEquals( + TestImpurityMeasure.asTestImpurityMeasures(3, 6, 5, 4, 3, 2), + c.getY() + ); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/util/TestImpurityMeasure.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/util/TestImpurityMeasure.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/util/TestImpurityMeasure.java new file mode 100644 index 0000000..c0d1911 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/util/TestImpurityMeasure.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.tree.impurity.util; + +import java.util.Objects; +import org.apache.ignite.ml.tree.impurity.ImpurityMeasure; + +/** + * Utils class used as impurity measure in tests. + */ +class TestImpurityMeasure implements ImpurityMeasure<TestImpurityMeasure> { + /** */ + private static final long serialVersionUID = 2414020770162797847L; + + /** Impurity. */ + private final double impurity; + + /** + * Constructs a new instance of test impurity measure. + * + * @param impurity Impurity. + */ + private TestImpurityMeasure(double impurity) { + this.impurity = impurity; + } + + /** + * Convert doubles to array of test impurity measures. + * + * @param impurity Impurity as array of doubles. + * @return Test impurity measure objects as array. + */ + static TestImpurityMeasure[] asTestImpurityMeasures(double... impurity) { + TestImpurityMeasure[] res = new TestImpurityMeasure[impurity.length]; + + for (int i = 0; i < impurity.length; i++) + res[i] = new TestImpurityMeasure(impurity[i]); + + return res; + } + + /** {@inheritDoc} */ + @Override public double impurity() { + return impurity; + } + + /** {@inheritDoc} */ + @Override public TestImpurityMeasure add(TestImpurityMeasure measure) { + return new TestImpurityMeasure(impurity + measure.impurity); + } + + /** {@inheritDoc} */ + @Override public TestImpurityMeasure subtract(TestImpurityMeasure measure) { + return new TestImpurityMeasure(impurity - measure.impurity); + } + + /** */ + @Override public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + TestImpurityMeasure measure = (TestImpurityMeasure)o; + + return Double.compare(measure.impurity, impurity) == 0; + } + + /** */ + @Override public int hashCode() { + + return Objects.hash(impurity); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTIntegrationTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTIntegrationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTIntegrationTest.java new file mode 100644 index 0000000..b259ec9 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTIntegrationTest.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.tree.performance; + +import java.io.IOException; +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; +import org.apache.ignite.configuration.CacheConfiguration; +import org.apache.ignite.internal.util.IgniteUtils; +import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder; +import org.apache.ignite.ml.nn.performance.MnistMLPTestUtil; +import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer; +import org.apache.ignite.ml.tree.DecisionTreeNode; +import org.apache.ignite.ml.tree.impurity.util.SimpleStepFunctionCompressor; +import org.apache.ignite.ml.util.MnistUtils; +import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; + +/** + * Tests {@link DecisionTreeClassificationTrainer} on the MNIST dataset that require to start the whole Ignite + * infrastructure. For manual run. + */ +public class DecisionTreeMNISTIntegrationTest extends GridCommonAbstractTest { + /** Number of nodes in grid */ + private static final int NODE_COUNT = 3; + + /** Ignite instance. */ + private Ignite ignite; + + /** {@inheritDoc} */ + @Override protected void beforeTestsStarted() throws Exception { + for (int i = 1; i <= NODE_COUNT; i++) + startGrid(i); + } + + /** {@inheritDoc} */ + @Override protected void afterTestsStopped() { + stopAllGrids(); + } + + /** + * {@inheritDoc} + */ + @Override protected void beforeTest() throws Exception { + /* Grid instance. */ + ignite = grid(NODE_COUNT); + ignite.configuration().setPeerClassLoadingEnabled(true); + IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); + } + + /** Tests on the MNIST dataset. For manual run. */ + public void testMNIST() throws IOException { + CacheConfiguration<Integer, MnistUtils.MnistLabeledImage> trainingSetCacheCfg = new CacheConfiguration<>(); + trainingSetCacheCfg.setAffinity(new RendezvousAffinityFunction(false, 10)); + trainingSetCacheCfg.setName("MNIST_TRAINING_SET"); + + IgniteCache<Integer, MnistUtils.MnistLabeledImage> trainingSet = ignite.createCache(trainingSetCacheCfg); + + int i = 0; + for (MnistUtils.MnistLabeledImage e : MnistMLPTestUtil.loadTrainingSet(60_000)) + trainingSet.put(i++, e); + + DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer( + 8, + 0, + new SimpleStepFunctionCompressor<>()); + + DecisionTreeNode mdl = trainer.fit( + new CacheBasedDatasetBuilder<>(ignite, trainingSet), + (k, v) -> v.getPixels(), + (k, v) -> (double) v.getLabel() + ); + + int correctAnswers = 0; + int incorrectAnswers = 0; + + for (MnistUtils.MnistLabeledImage e : MnistMLPTestUtil.loadTestSet(10_000)) { + double res = mdl.apply(e.getPixels()); + + if (res == e.getLabel()) + correctAnswers++; + else + incorrectAnswers++; + } + + double accuracy = 1.0 * correctAnswers / (correctAnswers + incorrectAnswers); + + assertTrue(accuracy > 0.8); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTTest.java new file mode 100644 index 0000000..6dbd44c --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTTest.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.tree.performance; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; +import org.apache.ignite.ml.nn.performance.MnistMLPTestUtil; +import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer; +import org.apache.ignite.ml.tree.DecisionTreeNode; +import org.apache.ignite.ml.tree.impurity.util.SimpleStepFunctionCompressor; +import org.apache.ignite.ml.util.MnistUtils; +import org.junit.Test; + +import static junit.framework.TestCase.assertTrue; + +/** + * Tests {@link DecisionTreeClassificationTrainer} on the MNIST dataset using locally stored data. For manual run. + */ +public class DecisionTreeMNISTTest { + /** Tests on the MNIST dataset. For manual run. */ + @Test + public void testMNIST() throws IOException { + Map<Integer, MnistUtils.MnistLabeledImage> trainingSet = new HashMap<>(); + + int i = 0; + for (MnistUtils.MnistLabeledImage e : MnistMLPTestUtil.loadTrainingSet(60_000)) + trainingSet.put(i++, e); + + + DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer( + 8, + 0, + new SimpleStepFunctionCompressor<>()); + + DecisionTreeNode mdl = trainer.fit( + new LocalDatasetBuilder<>(trainingSet, 10), + (k, v) -> v.getPixels(), + (k, v) -> (double) v.getLabel() + ); + + int correctAnswers = 0; + int incorrectAnswers = 0; + + for (MnistUtils.MnistLabeledImage e : MnistMLPTestUtil.loadTestSet(10_000)) { + double res = mdl.apply(e.getPixels()); + + if (res == e.getLabel()) + correctAnswers++; + else + incorrectAnswers++; + } + + double accuracy = 1.0 * correctAnswers / (correctAnswers + incorrectAnswers); + + assertTrue(accuracy > 0.8); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/test/java/org/apache/ignite/ml/trees/BaseDecisionTreeTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trees/BaseDecisionTreeTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/trees/BaseDecisionTreeTest.java deleted file mode 100644 index 65f0ae4..0000000 --- a/modules/ml/src/test/java/org/apache/ignite/ml/trees/BaseDecisionTreeTest.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trees; - -import java.util.Arrays; -import org.apache.ignite.Ignite; -import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; -import org.apache.ignite.ml.structures.LabeledVectorDouble; -import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; - -/** - * Base class for decision trees test. - */ -public class BaseDecisionTreeTest extends GridCommonAbstractTest { - /** Count of nodes. */ - private static final int NODE_COUNT = 4; - - /** Grid instance. */ - protected Ignite ignite; - - /** - * Default constructor. - */ - public BaseDecisionTreeTest() { - super(false); - } - - /** - * {@inheritDoc} - */ - @Override protected void beforeTest() throws Exception { - ignite = grid(NODE_COUNT); - } - - /** {@inheritDoc} */ - @Override protected void beforeTestsStarted() throws Exception { - for (int i = 1; i <= NODE_COUNT; i++) - startGrid(i); - } - - /** {@inheritDoc} */ - @Override protected void afterTestsStopped() throws Exception { - stopAllGrids(); - } - - /** - * Convert double array to {@link LabeledVectorDouble} - * - * @param arr Array for conversion. - * @return LabeledVectorDouble. - */ - protected static LabeledVectorDouble<DenseLocalOnHeapVector> asLabeledVector(double arr[]) { - return new LabeledVectorDouble<>(new DenseLocalOnHeapVector(Arrays.copyOf(arr, arr.length - 1)), arr[arr.length - 1]); - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/test/java/org/apache/ignite/ml/trees/ColumnDecisionTreeTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trees/ColumnDecisionTreeTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/trees/ColumnDecisionTreeTrainerTest.java deleted file mode 100644 index b090f43..0000000 --- a/modules/ml/src/test/java/org/apache/ignite/ml/trees/ColumnDecisionTreeTrainerTest.java +++ /dev/null @@ -1,191 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trees; - -import java.util.Collections; -import java.util.HashMap; -import java.util.LinkedList; -import java.util.List; -import java.util.Map; -import java.util.Random; -import java.util.stream.Collectors; -import java.util.stream.DoubleStream; -import org.apache.ignite.internal.util.IgniteUtils; -import org.apache.ignite.internal.util.typedef.X; -import org.apache.ignite.lang.IgniteBiTuple; -import org.apache.ignite.ml.math.StorageConstants; -import org.apache.ignite.ml.math.Tracer; -import org.apache.ignite.ml.math.functions.IgniteFunction; -import org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix; -import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; -import org.apache.ignite.ml.structures.LabeledVectorDouble; -import org.apache.ignite.ml.trees.models.DecisionTreeModel; -import org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer; -import org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainerInput; -import org.apache.ignite.ml.trees.trainers.columnbased.MatrixColumnDecisionTreeTrainerInput; -import org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.ContinuousSplitCalculators; -import org.apache.ignite.ml.trees.trainers.columnbased.regcalcs.RegionCalculators; - -/** Tests behaviour of ColumnDecisionTreeTrainer. */ -public class ColumnDecisionTreeTrainerTest extends BaseDecisionTreeTest { - /** - * Test {@link ColumnDecisionTreeTrainerTest} for mixed (continuous and categorical) data with Gini impurity. - */ - public void testCacheMixedGini() { - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - int totalPts = 1 << 10; - int featCnt = 2; - - HashMap<Integer, Integer> catsInfo = new HashMap<>(); - catsInfo.put(1, 3); - - Random rnd = new Random(12349L); - - SplitDataGenerator<DenseLocalOnHeapVector> gen = new SplitDataGenerator<>( - featCnt, catsInfo, () -> new DenseLocalOnHeapVector(featCnt + 1), rnd). - split(0, 1, new int[] {0, 2}). - split(1, 0, -10.0); - - testByGen(totalPts, catsInfo, gen, ContinuousSplitCalculators.GINI.apply(ignite), RegionCalculators.GINI, RegionCalculators.MEAN, rnd); - } - - /** - * Test {@link ColumnDecisionTreeTrainerTest} for mixed (continuous and categorical) data with Variance impurity. - */ - public void testCacheMixed() { - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - int totalPts = 1 << 10; - int featCnt = 2; - - HashMap<Integer, Integer> catsInfo = new HashMap<>(); - catsInfo.put(1, 3); - - Random rnd = new Random(12349L); - - SplitDataGenerator<DenseLocalOnHeapVector> gen = new SplitDataGenerator<>( - featCnt, catsInfo, () -> new DenseLocalOnHeapVector(featCnt + 1), rnd). - split(0, 1, new int[] {0, 2}). - split(1, 0, -10.0); - - testByGen(totalPts, catsInfo, gen, ContinuousSplitCalculators.VARIANCE, RegionCalculators.VARIANCE, RegionCalculators.MEAN, rnd); - } - - /** - * Test {@link ColumnDecisionTreeTrainerTest} for continuous data with Variance impurity. - */ - public void testCacheCont() { - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - int totalPts = 1 << 10; - int featCnt = 12; - - HashMap<Integer, Integer> catsInfo = new HashMap<>(); - - Random rnd = new Random(12349L); - - SplitDataGenerator<DenseLocalOnHeapVector> gen = new SplitDataGenerator<>( - featCnt, catsInfo, () -> new DenseLocalOnHeapVector(featCnt + 1), rnd). - split(0, 0, -10.0). - split(1, 0, 0.0). - split(1, 1, 2.0). - split(3, 7, 50.0); - - testByGen(totalPts, catsInfo, gen, ContinuousSplitCalculators.VARIANCE, RegionCalculators.VARIANCE, RegionCalculators.MEAN, rnd); - } - - /** - * Test {@link ColumnDecisionTreeTrainerTest} for continuous data with Gini impurity. - */ - public void testCacheContGini() { - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - int totalPts = 1 << 10; - int featCnt = 12; - - HashMap<Integer, Integer> catsInfo = new HashMap<>(); - - Random rnd = new Random(12349L); - - SplitDataGenerator<DenseLocalOnHeapVector> gen = new SplitDataGenerator<>( - featCnt, catsInfo, () -> new DenseLocalOnHeapVector(featCnt + 1), rnd). - split(0, 0, -10.0). - split(1, 0, 0.0). - split(1, 1, 2.0). - split(3, 7, 50.0); - - testByGen(totalPts, catsInfo, gen, ContinuousSplitCalculators.GINI.apply(ignite), RegionCalculators.GINI, RegionCalculators.MEAN, rnd); - } - - /** - * Test {@link ColumnDecisionTreeTrainerTest} for categorical data with Variance impurity. - */ - public void testCacheCat() { - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - int totalPts = 1 << 10; - int featCnt = 12; - - HashMap<Integer, Integer> catsInfo = new HashMap<>(); - catsInfo.put(5, 7); - - Random rnd = new Random(12349L); - - SplitDataGenerator<DenseLocalOnHeapVector> gen = new SplitDataGenerator<>( - featCnt, catsInfo, () -> new DenseLocalOnHeapVector(featCnt + 1), rnd). - split(0, 5, new int[] {0, 2, 5}); - - testByGen(totalPts, catsInfo, gen, ContinuousSplitCalculators.VARIANCE, RegionCalculators.VARIANCE, RegionCalculators.MEAN, rnd); - } - - /** */ - private <D extends ContinuousRegionInfo> void testByGen(int totalPts, HashMap<Integer, Integer> catsInfo, - SplitDataGenerator<DenseLocalOnHeapVector> gen, - IgniteFunction<ColumnDecisionTreeTrainerInput, ? extends ContinuousSplitCalculator<D>> calc, - IgniteFunction<ColumnDecisionTreeTrainerInput, IgniteFunction<DoubleStream, Double>> catImpCalc, - IgniteFunction<DoubleStream, Double> regCalc, Random rnd) { - - List<IgniteBiTuple<Integer, DenseLocalOnHeapVector>> lst = gen. - points(totalPts, (i, rn) -> i). - collect(Collectors.toList()); - - int featCnt = gen.featuresCnt(); - - Collections.shuffle(lst, rnd); - - SparseDistributedMatrix m = new SparseDistributedMatrix(totalPts, featCnt + 1, StorageConstants.COLUMN_STORAGE_MODE, StorageConstants.RANDOM_ACCESS_MODE); - - Map<Integer, List<LabeledVectorDouble>> byRegion = new HashMap<>(); - - int i = 0; - for (IgniteBiTuple<Integer, DenseLocalOnHeapVector> bt : lst) { - byRegion.putIfAbsent(bt.get1(), new LinkedList<>()); - byRegion.get(bt.get1()).add(asLabeledVector(bt.get2().getStorage().data())); - m.setRow(i, bt.get2().getStorage().data()); - i++; - } - - ColumnDecisionTreeTrainer<D> trainer = - new ColumnDecisionTreeTrainer<>(3, calc, catImpCalc, regCalc, ignite); - - DecisionTreeModel mdl = trainer.train(new MatrixColumnDecisionTreeTrainerInput(m, catsInfo)); - - byRegion.keySet().forEach(k -> { - LabeledVectorDouble sp = byRegion.get(k).get(0); - Tracer.showAscii(sp.features()); - X.println("Actual and predicted vectors [act=" + sp.label() + " " + ", pred=" + mdl.apply(sp.features()) + "]"); - assert mdl.apply(sp.features()) == sp.doubleLabel(); - }); - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/test/java/org/apache/ignite/ml/trees/DecisionTreesTestSuite.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trees/DecisionTreesTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/trees/DecisionTreesTestSuite.java deleted file mode 100644 index 3343503..0000000 --- a/modules/ml/src/test/java/org/apache/ignite/ml/trees/DecisionTreesTestSuite.java +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trees; - -import org.junit.runner.RunWith; -import org.junit.runners.Suite; - -/** - * Test suite for all tests located in org.apache.ignite.ml.trees package - */ -@RunWith(Suite.class) -@Suite.SuiteClasses({ - ColumnDecisionTreeTrainerTest.class, - GiniSplitCalculatorTest.class, - VarianceSplitCalculatorTest.class -}) -public class DecisionTreesTestSuite { -} http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/test/java/org/apache/ignite/ml/trees/GiniSplitCalculatorTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trees/GiniSplitCalculatorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/trees/GiniSplitCalculatorTest.java deleted file mode 100644 index c92b4f5..0000000 --- a/modules/ml/src/test/java/org/apache/ignite/ml/trees/GiniSplitCalculatorTest.java +++ /dev/null @@ -1,141 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trees; - -import java.util.stream.DoubleStream; -import org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.GiniSplitCalculator; -import org.apache.ignite.ml.trees.trainers.columnbased.vectors.SplitInfo; -import org.junit.Test; - -/** - * Test of {@link GiniSplitCalculator}. - */ -public class GiniSplitCalculatorTest { - /** Test calculation of region info consisting from one point. */ - @Test - public void testCalculateRegionInfoSimple() { - double labels[] = new double[] {0.0}; - - assert new GiniSplitCalculator(labels).calculateRegionInfo(DoubleStream.of(labels), 0).impurity() == 0.0; - } - - /** Test calculation of region info consisting from two distinct classes. */ - @Test - public void testCalculateRegionInfoTwoClasses() { - double labels[] = new double[] {0.0, 1.0}; - - assert new GiniSplitCalculator(labels).calculateRegionInfo(DoubleStream.of(labels), 0).impurity() == 0.5; - } - - /** Test calculation of region info consisting from three distinct classes. */ - @Test - public void testCalculateRegionInfoThreeClasses() { - double labels[] = new double[] {0.0, 1.0, 2.0}; - - assert Math.abs(new GiniSplitCalculator(labels).calculateRegionInfo(DoubleStream.of(labels), 0).impurity() - 2.0 / 3) < 1E-5; - } - - /** Test calculation of split of region consisting from one point. */ - @Test - public void testSplitSimple() { - double labels[] = new double[] {0.0}; - double values[] = new double[] {0.0}; - Integer[] samples = new Integer[] {0}; - - int cnts[] = new int[] {1}; - - GiniSplitCalculator.GiniData data = new GiniSplitCalculator.GiniData(0.0, 1, cnts, 1); - - assert new GiniSplitCalculator(labels).splitRegion(samples, values, labels, 0, data) == null; - } - - /** Test calculation of split of region consisting from two points. */ - @Test - public void testSplitTwoClassesTwoPoints() { - double labels[] = new double[] {0.0, 1.0}; - double values[] = new double[] {0.0, 1.0}; - Integer[] samples = new Integer[] {0, 1}; - - int cnts[] = new int[] {1, 1}; - - GiniSplitCalculator.GiniData data = new GiniSplitCalculator.GiniData(0.5, 2, cnts, 1.0 * 1.0 + 1.0 * 1.0); - - SplitInfo<GiniSplitCalculator.GiniData> split = new GiniSplitCalculator(labels).splitRegion(samples, values, labels, 0, data); - - assert split.leftData().impurity() == 0; - assert split.leftData().counts()[0] == 1; - assert split.leftData().counts()[1] == 0; - assert split.leftData().getSize() == 1; - - assert split.rightData().impurity() == 0; - assert split.rightData().counts()[0] == 0; - assert split.rightData().counts()[1] == 1; - assert split.rightData().getSize() == 1; - } - - /** Test calculation of split of region consisting from four distinct values. */ - @Test - public void testSplitTwoClassesFourPoints() { - double labels[] = new double[] {0.0, 0.0, 1.0, 1.0}; - double values[] = new double[] {0.0, 1.0, 2.0, 3.0}; - - Integer[] samples = new Integer[] {0, 1, 2, 3}; - - int[] cnts = new int[] {2, 2}; - - GiniSplitCalculator.GiniData data = new GiniSplitCalculator.GiniData(0.5, 4, cnts, 2.0 * 2.0 + 2.0 * 2.0); - - SplitInfo<GiniSplitCalculator.GiniData> split = new GiniSplitCalculator(labels).splitRegion(samples, values, labels, 0, data); - - assert split.leftData().impurity() == 0; - assert split.leftData().counts()[0] == 2; - assert split.leftData().counts()[1] == 0; - assert split.leftData().getSize() == 2; - - assert split.rightData().impurity() == 0; - assert split.rightData().counts()[0] == 0; - assert split.rightData().counts()[1] == 2; - assert split.rightData().getSize() == 2; - } - - /** Test calculation of split of region consisting from three distinct values. */ - @Test - public void testSplitThreePoints() { - double labels[] = new double[] {0.0, 1.0, 2.0}; - double values[] = new double[] {0.0, 1.0, 2.0}; - Integer[] samples = new Integer[] {0, 1, 2}; - - int[] cnts = new int[] {1, 1, 1}; - - GiniSplitCalculator.GiniData data = new GiniSplitCalculator.GiniData(2.0 / 3, 3, cnts, 1.0 * 1.0 + 1.0 * 1.0 + 1.0 * 1.0); - - SplitInfo<GiniSplitCalculator.GiniData> split = new GiniSplitCalculator(labels).splitRegion(samples, values, labels, 0, data); - - assert split.leftData().impurity() == 0.0; - assert split.leftData().counts()[0] == 1; - assert split.leftData().counts()[1] == 0; - assert split.leftData().counts()[2] == 0; - assert split.leftData().getSize() == 1; - - assert split.rightData().impurity() == 0.5; - assert split.rightData().counts()[0] == 0; - assert split.rightData().counts()[1] == 1; - assert split.rightData().counts()[2] == 1; - assert split.rightData().getSize() == 2; - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/test/java/org/apache/ignite/ml/trees/SplitDataGenerator.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trees/SplitDataGenerator.java b/modules/ml/src/test/java/org/apache/ignite/ml/trees/SplitDataGenerator.java deleted file mode 100644 index 279e685..0000000 --- a/modules/ml/src/test/java/org/apache/ignite/ml/trees/SplitDataGenerator.java +++ /dev/null @@ -1,390 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.ignite.ml.trees; - -import java.io.Serializable; -import java.util.Arrays; -import java.util.BitSet; -import java.util.HashMap; -import java.util.LinkedList; -import java.util.List; -import java.util.Map; -import java.util.Random; -import java.util.function.BiFunction; -import java.util.function.Function; -import java.util.function.Supplier; -import java.util.stream.Collectors; -import java.util.stream.IntStream; -import java.util.stream.Stream; -import org.apache.ignite.lang.IgniteBiTuple; -import org.apache.ignite.ml.math.Vector; -import org.apache.ignite.ml.math.exceptions.MathIllegalArgumentException; -import org.apache.ignite.ml.util.Utils; - -/** - * Utility class for generating data which has binary tree split structure. - * - * @param <V> - */ -public class SplitDataGenerator<V extends Vector> { - /** */ - private static final double DELTA = 100.0; - - /** Map of the form of (is categorical -> list of region indexes). */ - private final Map<Boolean, List<Integer>> di; - - /** List of regions. */ - private final List<Region> regs; - - /** Data of bounds of regions. */ - private final Map<Integer, IgniteBiTuple<Double, Double>> boundsData; - - /** Random numbers generator. */ - private final Random rnd; - - /** Supplier of vectors. */ - private final Supplier<V> supplier; - - /** Features count. */ - private final int featCnt; - - /** - * Create SplitDataGenerator. - * - * @param featCnt Features count. - * @param catFeaturesInfo Information about categorical features in form of map (feature index -> categories - * count). - * @param supplier Supplier of vectors. - * @param rnd Random numbers generator. - */ - public SplitDataGenerator(int featCnt, Map<Integer, Integer> catFeaturesInfo, Supplier<V> supplier, Random rnd) { - regs = new LinkedList<>(); - boundsData = new HashMap<>(); - this.rnd = rnd; - this.supplier = supplier; - this.featCnt = featCnt; - - // Divide indexes into indexes of categorical coordinates and indexes of continuous coordinates. - di = IntStream.range(0, featCnt). - boxed(). - collect(Collectors.partitioningBy(catFeaturesInfo::containsKey)); - - // Categorical coordinates info. - Map<Integer, CatCoordInfo> catCoords = new HashMap<>(); - di.get(true).forEach(i -> { - BitSet bs = new BitSet(); - bs.set(0, catFeaturesInfo.get(i)); - catCoords.put(i, new CatCoordInfo(bs)); - }); - - // Continuous coordinates info. - Map<Integer, ContCoordInfo> contCoords = new HashMap<>(); - di.get(false).forEach(i -> { - contCoords.put(i, new ContCoordInfo()); - boundsData.put(i, new IgniteBiTuple<>(-1.0, 1.0)); - }); - - Region firstReg = new Region(catCoords, contCoords, 0); - regs.add(firstReg); - } - - /** - * Categorical coordinate info. - */ - private static class CatCoordInfo implements Serializable { - /** - * Defines categories which are included in this region - */ - private final BitSet bs; - - /** - * Construct CatCoordInfo. - * - * @param bs Bitset. - */ - CatCoordInfo(BitSet bs) { - this.bs = bs; - } - - /** {@inheritDoc} */ - @Override public String toString() { - return "CatCoordInfo [" + - "bs=" + bs + - ']'; - } - } - - /** - * Continuous coordinate info. - */ - private static class ContCoordInfo implements Serializable { - /** - * Left (min) bound of region. - */ - private double left; - - /** - * Right (max) bound of region. - */ - private double right; - - /** - * Construct ContCoordInfo. - */ - ContCoordInfo() { - left = Double.NEGATIVE_INFINITY; - right = Double.POSITIVE_INFINITY; - } - - /** {@inheritDoc} */ - @Override public String toString() { - return "ContCoordInfo [" + - "left=" + left + - ", right=" + right + - ']'; - } - } - - /** - * Class representing information about region. - */ - private static class Region implements Serializable { - /** - * Information about categorical coordinates restrictions of this region in form of - * (coordinate index -> restriction) - */ - private final Map<Integer, CatCoordInfo> catCoords; - - /** - * Information about continuous coordinates restrictions of this region in form of - * (coordinate index -> restriction) - */ - private final Map<Integer, ContCoordInfo> contCoords; - - /** - * Region should contain {@code 1/2^twoPow * totalPoints} points. - */ - private int twoPow; - - /** - * Construct region by information about restrictions on coordinates (features) values. - * - * @param catCoords Restrictions on categorical coordinates. - * @param contCoords Restrictions on continuous coordinates - * @param twoPow Region should contain {@code 1/2^twoPow * totalPoints} points. - */ - Region(Map<Integer, CatCoordInfo> catCoords, Map<Integer, ContCoordInfo> contCoords, int twoPow) { - this.catCoords = catCoords; - this.contCoords = contCoords; - this.twoPow = twoPow; - } - - /** */ - int divideBy() { - return 1 << twoPow; - } - - /** */ - void incTwoPow() { - twoPow++; - } - - /** {@inheritDoc} */ - @Override public String toString() { - return "Region [" + - "catCoords=" + catCoords + - ", contCoords=" + contCoords + - ", twoPow=" + twoPow + - ']'; - } - - /** - * Generate continuous coordinate for this region. - * - * @param coordIdx Coordinate index. - * @param boundsData Data with bounds - * @param rnd Random numbers generator. - * @return Categorical coordinate value. - */ - double generateContCoord(int coordIdx, Map<Integer, IgniteBiTuple<Double, Double>> boundsData, - Random rnd) { - ContCoordInfo cci = contCoords.get(coordIdx); - double left = cci.left; - double right = cci.right; - - if (left == Double.NEGATIVE_INFINITY) - left = boundsData.get(coordIdx).get1() - DELTA; - - if (right == Double.POSITIVE_INFINITY) - right = boundsData.get(coordIdx).get2() + DELTA; - - double size = right - left; - - return left + rnd.nextDouble() * size; - } - - /** - * Generate categorical coordinate value for this region. - * - * @param coordIdx Coordinate index. - * @param rnd Random numbers generator. - * @return Categorical coordinate value. - */ - double generateCatCoord(int coordIdx, Random rnd) { - // Pick random bit. - BitSet bs = catCoords.get(coordIdx).bs; - int j = rnd.nextInt(bs.length()); - - int i = 0; - int bn = 0; - int bnp = 0; - - while ((bn = bs.nextSetBit(bn)) != -1 && i <= j) { - i++; - bnp = bn; - bn++; - } - - return bnp; - } - - /** - * Generate points for this region. - * - * @param ptsCnt Count of points to generate. - * @param val Label for all points in this region. - * @param boundsData Data about bounds of continuous coordinates. - * @param catCont Data about which categories can be in this region in the form (coordinate index -> list of - * categories indexes). - * @param s Vectors supplier. - * @param rnd Random numbers generator. - * @param <V> Type of vectors. - * @return Stream of generated points for this region. - */ - <V extends Vector> Stream<V> generatePoints(int ptsCnt, double val, - Map<Integer, IgniteBiTuple<Double, Double>> boundsData, Map<Boolean, List<Integer>> catCont, - Supplier<V> s, - Random rnd) { - return IntStream.range(0, ptsCnt / divideBy()).mapToObj(i -> { - V v = s.get(); - int coordsCnt = v.size(); - catCont.get(false).forEach(ci -> v.setX(ci, generateContCoord(ci, boundsData, rnd))); - catCont.get(true).forEach(ci -> v.setX(ci, generateCatCoord(ci, rnd))); - - v.setX(coordsCnt - 1, val); - return v; - }); - } - } - - /** - * Split region by continuous coordinate.using given threshold. - * - * @param regIdx Region index. - * @param coordIdx Coordinate index. - * @param threshold Threshold. - * @return {@code this}. - */ - public SplitDataGenerator<V> split(int regIdx, int coordIdx, double threshold) { - Region regToSplit = regs.get(regIdx); - ContCoordInfo cci = regToSplit.contCoords.get(coordIdx); - - double left = cci.left; - double right = cci.right; - - if (threshold < left || threshold > right) - throw new MathIllegalArgumentException("Threshold is out of region bounds."); - - regToSplit.incTwoPow(); - - Region newReg = Utils.copy(regToSplit); - newReg.contCoords.get(coordIdx).left = threshold; - - regs.add(regIdx + 1, newReg); - cci.right = threshold; - - IgniteBiTuple<Double, Double> bounds = boundsData.get(coordIdx); - double min = bounds.get1(); - double max = bounds.get2(); - boundsData.put(coordIdx, new IgniteBiTuple<>(Math.min(threshold, min), Math.max(max, threshold))); - - return this; - } - - /** - * Split region by categorical coordinate. - * - * @param regIdx Region index. - * @param coordIdx Coordinate index. - * @param cats Categories allowed for the left sub region. - * @return {@code this}. - */ - public SplitDataGenerator<V> split(int regIdx, int coordIdx, int[] cats) { - BitSet subset = new BitSet(); - Arrays.stream(cats).forEach(subset::set); - Region regToSplit = regs.get(regIdx); - CatCoordInfo cci = regToSplit.catCoords.get(coordIdx); - - BitSet ssc = (BitSet)subset.clone(); - BitSet set = cci.bs; - ssc.and(set); - if (ssc.length() != subset.length()) - throw new MathIllegalArgumentException("Splitter set is not a subset of a parent subset."); - - ssc.xor(set); - set.and(subset); - - regToSplit.incTwoPow(); - Region newReg = Utils.copy(regToSplit); - newReg.catCoords.put(coordIdx, new CatCoordInfo(ssc)); - - regs.add(regIdx + 1, newReg); - - return this; - } - - /** - * Get stream of points generated by this generator. - * - * @param ptsCnt Points count. - */ - public Stream<IgniteBiTuple<Integer, V>> points(int ptsCnt, BiFunction<Double, Random, Double> f) { - regs.forEach(System.out::println); - - return IntStream.range(0, regs.size()). - boxed(). - map(i -> regs.get(i).generatePoints(ptsCnt, f.apply((double)i, rnd), boundsData, di, supplier, rnd).map(v -> new IgniteBiTuple<>(i, v))).flatMap(Function.identity()); - } - - /** - * Count of regions. - * - * @return Count of regions. - */ - public int regsCount() { - return regs.size(); - } - - /** - * Get features count. - * - * @return Features count. - */ - public int featuresCnt() { - return featCnt; - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/test/java/org/apache/ignite/ml/trees/VarianceSplitCalculatorTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trees/VarianceSplitCalculatorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/trees/VarianceSplitCalculatorTest.java deleted file mode 100644 index d67cbc6..0000000 --- a/modules/ml/src/test/java/org/apache/ignite/ml/trees/VarianceSplitCalculatorTest.java +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trees; - -import java.util.stream.DoubleStream; -import org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.VarianceSplitCalculator; -import org.apache.ignite.ml.trees.trainers.columnbased.vectors.SplitInfo; -import org.junit.Test; - -/** - * Test for {@link VarianceSplitCalculator}. - */ -public class VarianceSplitCalculatorTest { - /** Test calculation of region info consisting from one point. */ - @Test - public void testCalculateRegionInfoSimple() { - double labels[] = new double[] {0.0}; - - assert new VarianceSplitCalculator().calculateRegionInfo(DoubleStream.of(labels), 1).impurity() == 0.0; - } - - /** Test calculation of region info consisting from two classes. */ - @Test - public void testCalculateRegionInfoTwoClasses() { - double labels[] = new double[] {0.0, 1.0}; - - assert new VarianceSplitCalculator().calculateRegionInfo(DoubleStream.of(labels), 2).impurity() == 0.25; - } - - /** Test calculation of region info consisting from three classes. */ - @Test - public void testCalculateRegionInfoThreeClasses() { - double labels[] = new double[] {1.0, 2.0, 3.0}; - - assert Math.abs(new VarianceSplitCalculator().calculateRegionInfo(DoubleStream.of(labels), 3).impurity() - 2.0 / 3) < 1E-10; - } - - /** Test calculation of split of region consisting from one point. */ - @Test - public void testSplitSimple() { - double labels[] = new double[] {0.0}; - double values[] = new double[] {0.0}; - Integer[] samples = new Integer[] {0}; - - VarianceSplitCalculator.VarianceData data = new VarianceSplitCalculator.VarianceData(0.0, 1, 0.0); - - assert new VarianceSplitCalculator().splitRegion(samples, values, labels, 0, data) == null; - } - - /** Test calculation of split of region consisting from two classes. */ - @Test - public void testSplitTwoClassesTwoPoints() { - double labels[] = new double[] {0.0, 1.0}; - double values[] = new double[] {0.0, 1.0}; - Integer[] samples = new Integer[] {0, 1}; - - VarianceSplitCalculator.VarianceData data = new VarianceSplitCalculator.VarianceData(0.25, 2, 0.5); - - SplitInfo<VarianceSplitCalculator.VarianceData> split = new VarianceSplitCalculator().splitRegion(samples, values, labels, 0, data); - - assert split.leftData().impurity() == 0; - assert split.leftData().mean() == 0; - assert split.leftData().getSize() == 1; - - assert split.rightData().impurity() == 0; - assert split.rightData().mean() == 1; - assert split.rightData().getSize() == 1; - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/test/java/org/apache/ignite/ml/trees/performance/ColumnDecisionTreeTrainerBenchmark.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trees/performance/ColumnDecisionTreeTrainerBenchmark.java b/modules/ml/src/test/java/org/apache/ignite/ml/trees/performance/ColumnDecisionTreeTrainerBenchmark.java deleted file mode 100644 index 21fd692..0000000 --- a/modules/ml/src/test/java/org/apache/ignite/ml/trees/performance/ColumnDecisionTreeTrainerBenchmark.java +++ /dev/null @@ -1,456 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trees.performance; - -import it.unimi.dsi.fastutil.ints.Int2DoubleOpenHashMap; -import java.io.IOException; -import java.io.InputStream; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.Iterator; -import java.util.LinkedList; -import java.util.List; -import java.util.Map; -import java.util.Properties; -import java.util.Random; -import java.util.UUID; -import java.util.function.Function; -import java.util.stream.Collectors; -import java.util.stream.DoubleStream; -import java.util.stream.IntStream; -import java.util.stream.Stream; -import org.apache.ignite.IgniteCache; -import org.apache.ignite.IgniteDataStreamer; -import org.apache.ignite.Ignition; -import org.apache.ignite.cache.CacheAtomicityMode; -import org.apache.ignite.cache.CacheMode; -import org.apache.ignite.cache.CacheWriteSynchronizationMode; -import org.apache.ignite.configuration.CacheConfiguration; -import org.apache.ignite.configuration.IgniteConfiguration; -import org.apache.ignite.internal.processors.cache.GridCacheProcessor; -import org.apache.ignite.internal.util.IgniteUtils; -import org.apache.ignite.internal.util.typedef.X; -import org.apache.ignite.lang.IgniteBiTuple; -import org.apache.ignite.ml.Model; -import org.apache.ignite.ml.estimators.Estimators; -import org.apache.ignite.ml.math.StorageConstants; -import org.apache.ignite.ml.math.Tracer; -import org.apache.ignite.ml.math.Vector; -import org.apache.ignite.ml.math.distributed.keys.impl.SparseMatrixKey; -import org.apache.ignite.ml.math.functions.IgniteFunction; -import org.apache.ignite.ml.math.functions.IgniteTriFunction; -import org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix; -import org.apache.ignite.ml.math.impls.storage.matrix.SparseDistributedMatrixStorage; -import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; -import org.apache.ignite.ml.structures.LabeledVectorDouble; -import org.apache.ignite.ml.trees.BaseDecisionTreeTest; -import org.apache.ignite.ml.trees.SplitDataGenerator; -import org.apache.ignite.ml.trees.models.DecisionTreeModel; -import org.apache.ignite.ml.trees.trainers.columnbased.BiIndex; -import org.apache.ignite.ml.trees.trainers.columnbased.BiIndexedCacheColumnDecisionTreeTrainerInput; -import org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer; -import org.apache.ignite.ml.trees.trainers.columnbased.MatrixColumnDecisionTreeTrainerInput; -import org.apache.ignite.ml.trees.trainers.columnbased.caches.ContextCache; -import org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache; -import org.apache.ignite.ml.trees.trainers.columnbased.caches.ProjectionsCache; -import org.apache.ignite.ml.trees.trainers.columnbased.caches.SplitCache; -import org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.ContinuousSplitCalculators; -import org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.GiniSplitCalculator; -import org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.VarianceSplitCalculator; -import org.apache.ignite.ml.trees.trainers.columnbased.regcalcs.RegionCalculators; -import org.apache.ignite.ml.util.MnistUtils; -import org.apache.ignite.stream.StreamTransformer; -import org.apache.ignite.testframework.junits.IgniteTestResources; -import org.apache.log4j.Level; -import org.junit.Assert; - -/** - * Various benchmarks for hand runs. - */ -public class ColumnDecisionTreeTrainerBenchmark extends BaseDecisionTreeTest { - /** Name of the property specifying path to training set images. */ - private static final String PROP_TRAINING_IMAGES = "mnist.training.images"; - - /** Name of property specifying path to training set labels. */ - private static final String PROP_TRAINING_LABELS = "mnist.training.labels"; - - /** Name of property specifying path to test set images. */ - private static final String PROP_TEST_IMAGES = "mnist.test.images"; - - /** Name of property specifying path to test set labels. */ - private static final String PROP_TEST_LABELS = "mnist.test.labels"; - - /** Function to approximate. */ - private static final Function<Vector, Double> f1 = v -> v.get(0) * v.get(0) + 2 * Math.sin(v.get(1)) + v.get(2); - - /** {@inheritDoc} */ - @Override protected long getTestTimeout() { - return 6000000; - } - - /** {@inheritDoc} */ - @Override protected IgniteConfiguration getConfiguration(String igniteInstanceName, - IgniteTestResources rsrcs) throws Exception { - IgniteConfiguration configuration = super.getConfiguration(igniteInstanceName, rsrcs); - // We do not need any extra event types. - configuration.setIncludeEventTypes(); - configuration.setPeerClassLoadingEnabled(false); - - resetLog4j(Level.INFO, false, GridCacheProcessor.class.getPackage().getName()); - - return configuration; - } - - /** - * This test is for manual run only. - * To run this test rename this method so it starts from 'test'. - */ - public void tstCacheMixed() { - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - int ptsPerReg = 150; - int featCnt = 10; - - HashMap<Integer, Integer> catsInfo = new HashMap<>(); - catsInfo.put(1, 3); - - Random rnd = new Random(12349L); - - SplitDataGenerator<DenseLocalOnHeapVector> gen = new SplitDataGenerator<>( - featCnt, catsInfo, () -> new DenseLocalOnHeapVector(featCnt + 1), rnd). - split(0, 1, new int[] {0, 2}). - split(1, 0, -10.0). - split(0, 0, 0.0); - - testByGenStreamerLoad(ptsPerReg, catsInfo, gen, rnd); - } - - /** - * Run decision tree classifier on MNIST using bi-indexed cache as a storage for dataset. - * To run this test rename this method so it starts from 'test'. - * - * @throws IOException In case of loading MNIST dataset errors. - */ - public void tstMNISTBiIndexedCache() throws IOException { - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - - int ptsCnt = 40_000; - int featCnt = 28 * 28; - - Properties props = loadMNISTProperties(); - - Stream<DenseLocalOnHeapVector> trainingMnistStream = MnistUtils.mnistAsStream(props.getProperty(PROP_TRAINING_IMAGES), props.getProperty(PROP_TRAINING_LABELS), new Random(123L), ptsCnt); - Stream<DenseLocalOnHeapVector> testMnistStream = MnistUtils.mnistAsStream(props.getProperty(PROP_TEST_IMAGES), props.getProperty(PROP_TEST_LABELS), new Random(123L), 10_000); - - IgniteCache<BiIndex, Double> cache = createBiIndexedCache(); - - loadVectorsIntoBiIndexedCache(cache.getName(), trainingMnistStream.iterator(), featCnt + 1); - - ColumnDecisionTreeTrainer<GiniSplitCalculator.GiniData> trainer = - new ColumnDecisionTreeTrainer<>(10, ContinuousSplitCalculators.GINI.apply(ignite), RegionCalculators.GINI, RegionCalculators.MOST_COMMON, ignite); - - X.println("Training started."); - long before = System.currentTimeMillis(); - DecisionTreeModel mdl = trainer.train(new BiIndexedCacheColumnDecisionTreeTrainerInput(cache, new HashMap<>(), ptsCnt, featCnt)); - X.println("Training finished in " + (System.currentTimeMillis() - before)); - - IgniteTriFunction<Model<Vector, Double>, Stream<IgniteBiTuple<Vector, Double>>, Function<Double, Double>, Double> mse = Estimators.errorsPercentage(); - Double accuracy = mse.apply(mdl, testMnistStream.map(v -> new IgniteBiTuple<>(v.viewPart(0, featCnt), v.getX(featCnt))), Function.identity()); - X.println("Errors percentage: " + accuracy); - - Assert.assertEquals(0, SplitCache.getOrCreate(ignite).size()); - Assert.assertEquals(0, FeaturesCache.getOrCreate(ignite).size()); - Assert.assertEquals(0, ContextCache.getOrCreate(ignite).size()); - Assert.assertEquals(0, ProjectionsCache.getOrCreate(ignite).size()); - } - - /** - * Run decision tree classifier on MNIST using sparse distributed matrix as a storage for dataset. - * To run this test rename this method so it starts from 'test'. - * - * @throws IOException In case of loading MNIST dataset errors. - */ - public void tstMNISTSparseDistributedMatrix() throws IOException { - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - - int ptsCnt = 30_000; - int featCnt = 28 * 28; - - Properties props = loadMNISTProperties(); - - Stream<DenseLocalOnHeapVector> trainingMnistStream = MnistUtils.mnistAsStream(props.getProperty(PROP_TRAINING_IMAGES), props.getProperty(PROP_TRAINING_LABELS), new Random(123L), ptsCnt); - Stream<DenseLocalOnHeapVector> testMnistStream = MnistUtils.mnistAsStream(props.getProperty(PROP_TEST_IMAGES), props.getProperty(PROP_TEST_LABELS), new Random(123L), 10_000); - - SparseDistributedMatrix m = new SparseDistributedMatrix(ptsCnt, featCnt + 1, StorageConstants.COLUMN_STORAGE_MODE, StorageConstants.RANDOM_ACCESS_MODE); - - SparseDistributedMatrixStorage sto = (SparseDistributedMatrixStorage)m.getStorage(); - - loadVectorsIntoSparseDistributedMatrixCache(sto.cache().getName(), sto.getUUID(), trainingMnistStream.iterator(), featCnt + 1); - - ColumnDecisionTreeTrainer<GiniSplitCalculator.GiniData> trainer = - new ColumnDecisionTreeTrainer<>(10, ContinuousSplitCalculators.GINI.apply(ignite), RegionCalculators.GINI, RegionCalculators.MOST_COMMON, ignite); - - X.println("Training started"); - long before = System.currentTimeMillis(); - DecisionTreeModel mdl = trainer.train(new MatrixColumnDecisionTreeTrainerInput(m, new HashMap<>())); - X.println("Training finished in " + (System.currentTimeMillis() - before)); - - IgniteTriFunction<Model<Vector, Double>, Stream<IgniteBiTuple<Vector, Double>>, Function<Double, Double>, Double> mse = Estimators.errorsPercentage(); - Double accuracy = mse.apply(mdl, testMnistStream.map(v -> new IgniteBiTuple<>(v.viewPart(0, featCnt), v.getX(featCnt))), Function.identity()); - X.println("Errors percentage: " + accuracy); - - Assert.assertEquals(0, SplitCache.getOrCreate(ignite).size()); - Assert.assertEquals(0, FeaturesCache.getOrCreate(ignite).size()); - Assert.assertEquals(0, ContextCache.getOrCreate(ignite).size()); - Assert.assertEquals(0, ProjectionsCache.getOrCreate(ignite).size()); - } - - /** Load properties for MNIST tests. */ - private static Properties loadMNISTProperties() throws IOException { - Properties res = new Properties(); - - InputStream is = ColumnDecisionTreeTrainerBenchmark.class.getClassLoader().getResourceAsStream("manualrun/trees/columntrees.manualrun.properties"); - - res.load(is); - - return res; - } - - /** */ - private void testByGenStreamerLoad(int ptsPerReg, HashMap<Integer, Integer> catsInfo, - SplitDataGenerator<DenseLocalOnHeapVector> gen, Random rnd) { - - List<IgniteBiTuple<Integer, DenseLocalOnHeapVector>> lst = gen. - points(ptsPerReg, (i, rn) -> i). - collect(Collectors.toList()); - - int featCnt = gen.featuresCnt(); - - Collections.shuffle(lst, rnd); - - int numRegs = gen.regsCount(); - - SparseDistributedMatrix m = new SparseDistributedMatrix(numRegs * ptsPerReg, featCnt + 1, StorageConstants.COLUMN_STORAGE_MODE, StorageConstants.RANDOM_ACCESS_MODE); - - IgniteFunction<DoubleStream, Double> regCalc = s -> s.average().orElse(0.0); - - Map<Integer, List<LabeledVectorDouble>> byRegion = new HashMap<>(); - - SparseDistributedMatrixStorage sto = (SparseDistributedMatrixStorage)m.getStorage(); - long before = System.currentTimeMillis(); - X.println("Batch loading started..."); - loadVectorsIntoSparseDistributedMatrixCache(sto.cache().getName(), sto.getUUID(), gen. - points(ptsPerReg, (i, rn) -> i).map(IgniteBiTuple::get2).iterator(), featCnt + 1); - X.println("Batch loading took " + (System.currentTimeMillis() - before) + " ms."); - - for (IgniteBiTuple<Integer, DenseLocalOnHeapVector> bt : lst) { - byRegion.putIfAbsent(bt.get1(), new LinkedList<>()); - byRegion.get(bt.get1()).add(asLabeledVector(bt.get2().getStorage().data())); - } - - ColumnDecisionTreeTrainer<VarianceSplitCalculator.VarianceData> trainer = - new ColumnDecisionTreeTrainer<>(2, ContinuousSplitCalculators.VARIANCE, RegionCalculators.VARIANCE, regCalc, ignite); - - before = System.currentTimeMillis(); - DecisionTreeModel mdl = trainer.train(new MatrixColumnDecisionTreeTrainerInput(m, catsInfo)); - - X.println("Training took: " + (System.currentTimeMillis() - before) + " ms."); - - byRegion.keySet().forEach(k -> { - LabeledVectorDouble sp = byRegion.get(k).get(0); - Tracer.showAscii(sp.features()); - X.println("Predicted value and label [pred=" + mdl.apply(sp.features()) + ", label=" + sp.doubleLabel() + "]"); - assert mdl.apply(sp.features()) == sp.doubleLabel(); - }); - } - - /** - * Test decision tree regression. - * To run this test rename this method so it starts from 'test'. - */ - public void tstF1() { - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - int ptsCnt = 10000; - Map<Integer, double[]> ranges = new HashMap<>(); - - ranges.put(0, new double[] {-100.0, 100.0}); - ranges.put(1, new double[] {-100.0, 100.0}); - ranges.put(2, new double[] {-100.0, 100.0}); - - int featCnt = 100; - double[] defRng = {-1.0, 1.0}; - - Vector[] trainVectors = vecsFromRanges(ranges, featCnt, defRng, new Random(123L), ptsCnt, f1); - - SparseDistributedMatrix m = new SparseDistributedMatrix(ptsCnt, featCnt + 1, StorageConstants.COLUMN_STORAGE_MODE, StorageConstants.RANDOM_ACCESS_MODE); - - SparseDistributedMatrixStorage sto = (SparseDistributedMatrixStorage)m.getStorage(); - - loadVectorsIntoSparseDistributedMatrixCache(sto.cache().getName(), sto.getUUID(), Arrays.stream(trainVectors).iterator(), featCnt + 1); - - IgniteFunction<DoubleStream, Double> regCalc = s -> s.average().orElse(0.0); - - ColumnDecisionTreeTrainer<VarianceSplitCalculator.VarianceData> trainer = - new ColumnDecisionTreeTrainer<>(10, ContinuousSplitCalculators.VARIANCE, RegionCalculators.VARIANCE, regCalc, ignite); - - X.println("Training started."); - long before = System.currentTimeMillis(); - DecisionTreeModel mdl = trainer.train(new MatrixColumnDecisionTreeTrainerInput(m, new HashMap<>())); - X.println("Training finished in: " + (System.currentTimeMillis() - before) + " ms."); - - Vector[] testVectors = vecsFromRanges(ranges, featCnt, defRng, new Random(123L), 20, f1); - - IgniteTriFunction<Model<Vector, Double>, Stream<IgniteBiTuple<Vector, Double>>, Function<Double, Double>, Double> mse = Estimators.MSE(); - Double accuracy = mse.apply(mdl, Arrays.stream(testVectors).map(v -> new IgniteBiTuple<>(v.viewPart(0, featCnt), v.getX(featCnt))), Function.identity()); - X.println("MSE: " + accuracy); - } - - /** - * Load vectors into sparse distributed matrix. - * - * @param cacheName Name of cache where matrix is stored. - * @param uuid UUID of matrix. - * @param iter Iterator over vectors. - * @param vectorSize size of vectors. - */ - private void loadVectorsIntoSparseDistributedMatrixCache(String cacheName, UUID uuid, - Iterator<? extends org.apache.ignite.ml.math.Vector> iter, int vectorSize) { - try (IgniteDataStreamer<SparseMatrixKey, Map<Integer, Double>> streamer = - Ignition.localIgnite().dataStreamer(cacheName)) { - int sampleIdx = 0; - streamer.allowOverwrite(true); - - streamer.receiver(StreamTransformer.from((e, arg) -> { - Map<Integer, Double> val = e.getValue(); - - if (val == null) - val = new Int2DoubleOpenHashMap(); - - val.putAll((Map<Integer, Double>)arg[0]); - - e.setValue(val); - - return null; - })); - - // Feature index -> (sample index -> value) - Map<Integer, Map<Integer, Double>> batch = new HashMap<>(); - IntStream.range(0, vectorSize).forEach(i -> batch.put(i, new HashMap<>())); - int batchSize = 1000; - - while (iter.hasNext()) { - org.apache.ignite.ml.math.Vector next = iter.next(); - - for (int i = 0; i < vectorSize; i++) - batch.get(i).put(sampleIdx, next.getX(i)); - - X.println("Sample index: " + sampleIdx); - if (sampleIdx % batchSize == 0) { - batch.keySet().forEach(fi -> streamer.addData(new SparseMatrixKey(fi, uuid, fi), batch.get(fi))); - IntStream.range(0, vectorSize).forEach(i -> batch.put(i, new HashMap<>())); - } - sampleIdx++; - } - if (sampleIdx % batchSize != 0) { - batch.keySet().forEach(fi -> streamer.addData(new SparseMatrixKey(fi, uuid, fi), batch.get(fi))); - IntStream.range(0, vectorSize).forEach(i -> batch.put(i, new HashMap<>())); - } - } - } - - /** - * Load vectors into bi-indexed cache. - * - * @param cacheName Name of cache. - * @param iter Iterator over vectors. - * @param vectorSize size of vectors. - */ - private void loadVectorsIntoBiIndexedCache(String cacheName, - Iterator<? extends org.apache.ignite.ml.math.Vector> iter, int vectorSize) { - try (IgniteDataStreamer<BiIndex, Double> streamer = - Ignition.localIgnite().dataStreamer(cacheName)) { - int sampleIdx = 0; - - streamer.perNodeBufferSize(10000); - - while (iter.hasNext()) { - org.apache.ignite.ml.math.Vector next = iter.next(); - - for (int i = 0; i < vectorSize; i++) - streamer.addData(new BiIndex(sampleIdx, i), next.getX(i)); - - sampleIdx++; - - if (sampleIdx % 1000 == 0) - System.out.println("Loaded: " + sampleIdx + " vectors."); - } - } - } - - /** - * Create bi-indexed cache for tests. - * - * @return Bi-indexed cache. - */ - private IgniteCache<BiIndex, Double> createBiIndexedCache() { - CacheConfiguration<BiIndex, Double> cfg = new CacheConfiguration<>(); - - // Write to primary. - cfg.setWriteSynchronizationMode(CacheWriteSynchronizationMode.PRIMARY_SYNC); - - // Atomic transactions only. - cfg.setAtomicityMode(CacheAtomicityMode.ATOMIC); - - // No eviction. - cfg.setEvictionPolicy(null); - - // No copying of values. - cfg.setCopyOnRead(false); - - // Cache is partitioned. - cfg.setCacheMode(CacheMode.PARTITIONED); - - cfg.setBackups(0); - - cfg.setName("TMP_BI_INDEXED_CACHE"); - - return Ignition.localIgnite().getOrCreateCache(cfg); - } - - /** */ - private Vector[] vecsFromRanges(Map<Integer, double[]> ranges, int featCnt, double[] defRng, Random rnd, int ptsCnt, - Function<Vector, Double> f) { - int vs = featCnt + 1; - DenseLocalOnHeapVector[] res = new DenseLocalOnHeapVector[ptsCnt]; - for (int pt = 0; pt < ptsCnt; pt++) { - DenseLocalOnHeapVector v = new DenseLocalOnHeapVector(vs); - for (int i = 0; i < featCnt; i++) { - double[] range = ranges.getOrDefault(i, defRng); - double from = range[0]; - double to = range[1]; - double rng = to - from; - - v.setX(i, rnd.nextDouble() * rng); - } - v.setX(featCnt, f.apply(v)); - res[pt] = v; - } - - return res; - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/yardstick/src/main/java/org/apache/ignite/yardstick/ml/trees/IgniteColumnDecisionTreeGiniBenchmark.java ---------------------------------------------------------------------- diff --git a/modules/yardstick/src/main/java/org/apache/ignite/yardstick/ml/trees/IgniteColumnDecisionTreeGiniBenchmark.java b/modules/yardstick/src/main/java/org/apache/ignite/yardstick/ml/trees/IgniteColumnDecisionTreeGiniBenchmark.java deleted file mode 100644 index f8a7c08..0000000 --- a/modules/yardstick/src/main/java/org/apache/ignite/yardstick/ml/trees/IgniteColumnDecisionTreeGiniBenchmark.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.yardstick.ml.trees; - -import java.util.HashMap; -import java.util.Map; -import org.apache.ignite.Ignite; -import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; -import org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.ContinuousSplitCalculators; -import org.apache.ignite.ml.trees.trainers.columnbased.regcalcs.RegionCalculators; -import org.apache.ignite.resources.IgniteInstanceResource; -import org.apache.ignite.thread.IgniteThread; -import org.apache.ignite.yardstick.IgniteAbstractBenchmark; - -/** - * Ignite benchmark that performs ML Grid operations. - */ -@SuppressWarnings("unused") -public class IgniteColumnDecisionTreeGiniBenchmark extends IgniteAbstractBenchmark { - /** */ - @IgniteInstanceResource - private Ignite ignite; - - /** {@inheritDoc} */ - @Override public boolean test(Map<Object, Object> ctx) throws Exception { - // Create IgniteThread, we must work with SparseDistributedMatrix inside IgniteThread - // because we create ignite cache internally. - IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), - this.getClass().getSimpleName(), new Runnable() { - /** {@inheritDoc} */ - @Override public void run() { - // IMPL NOTE originally taken from ColumnDecisionTreeTrainerTest#testCacheMixedGini - int totalPts = 1 << 10; - int featCnt = 2; - - HashMap<Integer, Integer> catsInfo = new HashMap<>(); - catsInfo.put(1, 3); - - SplitDataGenerator<DenseLocalOnHeapVector> gen = new SplitDataGenerator<>( - featCnt, catsInfo, () -> new DenseLocalOnHeapVector(featCnt + 1)). - split(0, 1, new int[] {0, 2}). - split(1, 0, -10.0); - - gen.testByGen(totalPts, ContinuousSplitCalculators.GINI.apply(ignite), - RegionCalculators.GINI, RegionCalculators.MEAN, ignite); - } - }); - - igniteThread.start(); - - igniteThread.join(); - - return true; - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/yardstick/src/main/java/org/apache/ignite/yardstick/ml/trees/IgniteColumnDecisionTreeVarianceBenchmark.java ---------------------------------------------------------------------- diff --git a/modules/yardstick/src/main/java/org/apache/ignite/yardstick/ml/trees/IgniteColumnDecisionTreeVarianceBenchmark.java b/modules/yardstick/src/main/java/org/apache/ignite/yardstick/ml/trees/IgniteColumnDecisionTreeVarianceBenchmark.java deleted file mode 100644 index f9d417f..0000000 --- a/modules/yardstick/src/main/java/org/apache/ignite/yardstick/ml/trees/IgniteColumnDecisionTreeVarianceBenchmark.java +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.yardstick.ml.trees; - -import java.util.HashMap; -import java.util.Map; -import org.apache.ignite.Ignite; -import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; -import org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.ContinuousSplitCalculators; -import org.apache.ignite.ml.trees.trainers.columnbased.regcalcs.RegionCalculators; -import org.apache.ignite.resources.IgniteInstanceResource; -import org.apache.ignite.thread.IgniteThread; -import org.apache.ignite.yardstick.IgniteAbstractBenchmark; - -/** - * Ignite benchmark that performs ML Grid operations. - */ -@SuppressWarnings("unused") -public class IgniteColumnDecisionTreeVarianceBenchmark extends IgniteAbstractBenchmark { - /** */ - @IgniteInstanceResource - private Ignite ignite; - - /** {@inheritDoc} */ - @Override public boolean test(Map<Object, Object> ctx) throws Exception { - // Create IgniteThread, we must work with SparseDistributedMatrix inside IgniteThread - // because we create ignite cache internally. - IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), - this.getClass().getSimpleName(), new Runnable() { - /** {@inheritDoc} */ - @Override public void run() { - // IMPL NOTE originally taken from ColumnDecisionTreeTrainerTest#testCacheMixed - int totalPts = 1 << 10; - int featCnt = 2; - - HashMap<Integer, Integer> catsInfo = new HashMap<>(); - catsInfo.put(1, 3); - - SplitDataGenerator<DenseLocalOnHeapVector> gen - = new SplitDataGenerator<>( - featCnt, catsInfo, () -> new DenseLocalOnHeapVector(featCnt + 1)). - split(0, 1, new int[] {0, 2}). - split(1, 0, -10.0); - - gen.testByGen(totalPts, - ContinuousSplitCalculators.VARIANCE, RegionCalculators.VARIANCE, RegionCalculators.MEAN, ignite); - } - }); - - igniteThread.start(); - - igniteThread.join(); - - return true; - } -}