http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/contsplitcalcs/VarianceSplitCalculator.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/contsplitcalcs/VarianceSplitCalculator.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/contsplitcalcs/VarianceSplitCalculator.java deleted file mode 100644 index 66c54f2..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/contsplitcalcs/VarianceSplitCalculator.java +++ /dev/null @@ -1,179 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs; - -import java.io.IOException; -import java.io.ObjectInput; -import java.io.ObjectOutput; -import java.util.PrimitiveIterator; -import java.util.stream.DoubleStream; -import org.apache.ignite.ml.trees.ContinuousRegionInfo; -import org.apache.ignite.ml.trees.ContinuousSplitCalculator; -import org.apache.ignite.ml.trees.trainers.columnbased.vectors.ContinuousSplitInfo; -import org.apache.ignite.ml.trees.trainers.columnbased.vectors.SplitInfo; - -/** - * Calculator of variance in a given region. - */ -public class VarianceSplitCalculator implements ContinuousSplitCalculator<VarianceSplitCalculator.VarianceData> { - /** - * Data used in variance calculations. - */ - public static class VarianceData extends ContinuousRegionInfo { - /** Mean value in a given region. */ - double mean; - - /** - * @param var Variance in this region. - * @param size Size of data for which variance is calculated. - * @param mean Mean value in this region. - */ - public VarianceData(double var, int size, double mean) { - super(var, size); - this.mean = mean; - } - - /** - * No-op constructor. For serialization/deserialization. - */ - public VarianceData() { - // No-op. - } - - /** {@inheritDoc} */ - @Override public void writeExternal(ObjectOutput out) throws IOException { - super.writeExternal(out); - out.writeDouble(mean); - } - - /** {@inheritDoc} */ - @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { - super.readExternal(in); - mean = in.readDouble(); - } - - /** - * Returns mean. - */ - public double mean() { - return mean; - } - } - - /** {@inheritDoc} */ - @Override public VarianceData calculateRegionInfo(DoubleStream s, int size) { - PrimitiveIterator.OfDouble itr = s.iterator(); - int i = 0; - - double mean = 0.0; - double m2 = 0.0; - - // Here we calculate variance and mean by incremental computation. 
- while (itr.hasNext()) { - i++; - double x = itr.next(); - double delta = x - mean; - mean += delta / i; - double delta2 = x - mean; - m2 += delta * delta2; - } - - return new VarianceData(m2 / i, size, mean); - } - - /** {@inheritDoc} */ - @Override public SplitInfo<VarianceData> splitRegion(Integer[] s, double[] values, double[] labels, int regionIdx, - VarianceData d) { - int size = d.getSize(); - - double lm2 = 0.0; - double rm2 = d.impurity() * size; - int lSize = size; - - double lMean = 0.0; - double rMean = d.mean; - - double minImpurity = d.impurity() * size; - double curThreshold; - double curImpurity; - double threshold = Double.NEGATIVE_INFINITY; - - int i = 0; - int nextIdx = s[0]; - i++; - double[] lrImps = new double[] {lm2, rm2, lMean, rMean}; - - do { - // Process all values equal to prev. - while (i < s.length) { - moveLeft(labels[nextIdx], lrImps[2], i, lrImps[0], lrImps[3], size - i, lrImps[1], lrImps); - curImpurity = (lrImps[0] + lrImps[1]); - curThreshold = values[nextIdx]; - - if (values[nextIdx] != values[(nextIdx = s[i++])]) { - if (curImpurity < minImpurity) { - lSize = i - 1; - - lm2 = lrImps[0]; - rm2 = lrImps[1]; - - lMean = lrImps[2]; - rMean = lrImps[3]; - - minImpurity = curImpurity; - threshold = curThreshold; - } - - break; - } - } - } - while (i < s.length - 1); - - if (lSize == size) - return null; - - VarianceData lData = new VarianceData(lm2 / (lSize != 0 ? lSize : 1), lSize, lMean); - int rSize = size - lSize; - VarianceData rData = new VarianceData(rm2 / (rSize != 0 ? rSize : 1), rSize, rMean); - - return new ContinuousSplitInfo<>(regionIdx, threshold, lData, rData); - } - - /** - * Add point to the left interval and remove it from the right interval and calculate necessary statistics on - * intervals with new bounds. - */ - private void moveLeft(double x, double lMean, int lSize, double lm2, double rMean, int rSize, double rm2, - double[] data) { - // We add point to the left interval. 
- double lDelta = x - lMean; - double lMeanNew = lMean + lDelta / lSize; - double lm2New = lm2 + lDelta * (x - lMeanNew); - - // We remove point from the right interval. lSize + 1 is the size of right interval before removal. - double rMeanNew = (rMean * (rSize + 1) - x) / rSize; - double rm2New = rm2 - (x - rMean) * (x - rMeanNew); - - data[0] = lm2New; - data[1] = rm2New; - - data[2] = lMeanNew; - data[3] = rMeanNew; - } -}
http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/contsplitcalcs/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/contsplitcalcs/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/contsplitcalcs/package-info.java deleted file mode 100644 index 08c8a75..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/contsplitcalcs/package-info.java +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * <!-- Package description. --> - * Calculators of splits by continuous features. 
- */ -package org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/package-info.java deleted file mode 100644 index 8523914..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/package-info.java +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * <!-- Package description. --> - * Contains column based decision tree algorithms. 
- */ -package org.apache.ignite.ml.trees.trainers.columnbased; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/regcalcs/RegionCalculators.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/regcalcs/RegionCalculators.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/regcalcs/RegionCalculators.java deleted file mode 100644 index 5c4b354..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/regcalcs/RegionCalculators.java +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.ignite.ml.trees.trainers.columnbased.regcalcs; - -import it.unimi.dsi.fastutil.doubles.Double2IntOpenHashMap; -import java.util.Comparator; -import java.util.HashMap; -import java.util.Map; -import java.util.PrimitiveIterator; -import java.util.stream.DoubleStream; -import org.apache.ignite.ml.math.functions.IgniteFunction; -import org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainerInput; - -/** Some commonly used functions for calculations of regions of space which correspond to decision tree leaf nodes. */ -public class RegionCalculators { - /** Mean value in the region. */ - public static final IgniteFunction<DoubleStream, Double> MEAN = s -> s.average().orElse(0.0); - - /** Most common value in the region. */ - public static final IgniteFunction<DoubleStream, Double> MOST_COMMON = - s -> { - PrimitiveIterator.OfDouble itr = s.iterator(); - Map<Double, Integer> voc = new HashMap<>(); - - while (itr.hasNext()) - voc.compute(itr.next(), (d, i) -> i != null ? i + 1 : 0); - - return voc.entrySet().stream().max(Comparator.comparing(Map.Entry::getValue)).map(Map.Entry::getKey).orElse(0.0); - }; - - /** Variance of a region. */ - public static final IgniteFunction<ColumnDecisionTreeTrainerInput, IgniteFunction<DoubleStream, Double>> VARIANCE = input -> - s -> { - PrimitiveIterator.OfDouble itr = s.iterator(); - int i = 0; - - double mean = 0.0; - double m2 = 0.0; - - while (itr.hasNext()) { - i++; - double x = itr.next(); - double delta = x - mean; - mean += delta / i; - double delta2 = x - mean; - m2 += delta * delta2; - } - - return i > 0 ? m2 / i : 0.0; - }; - - /** Gini impurity of a region. 
*/ - public static final IgniteFunction<ColumnDecisionTreeTrainerInput, IgniteFunction<DoubleStream, Double>> GINI = input -> - s -> { - PrimitiveIterator.OfDouble itr = s.iterator(); - - Double2IntOpenHashMap m = new Double2IntOpenHashMap(); - - int size = 0; - - while (itr.hasNext()) { - size++; - m.compute(itr.next(), (a, i) -> i != null ? i + 1 : 1); - } - - double c2 = m.values().stream().mapToDouble(v -> v * v).sum(); - - return size != 0 ? 1 - c2 / (size * size) : 0.0; - }; -} http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/regcalcs/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/regcalcs/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/regcalcs/package-info.java deleted file mode 100644 index e8edd8f..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/regcalcs/package-info.java +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * <!-- Package description. --> - * Region calculators. 
- */ -package org.apache.ignite.ml.trees.trainers.columnbased.regcalcs; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/CategoricalFeatureProcessor.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/CategoricalFeatureProcessor.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/CategoricalFeatureProcessor.java deleted file mode 100644 index 3232ac2..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/CategoricalFeatureProcessor.java +++ /dev/null @@ -1,212 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.ignite.ml.trees.trainers.columnbased.vectors; - -import com.zaxxer.sparsebits.SparseBitSet; -import java.util.Arrays; -import java.util.BitSet; -import java.util.Comparator; -import java.util.HashMap; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; -import java.util.stream.DoubleStream; -import java.util.stream.Stream; -import java.util.stream.StreamSupport; -import org.apache.ignite.lang.IgniteBiTuple; -import org.apache.ignite.ml.math.functions.IgniteFunction; -import org.apache.ignite.ml.trees.CategoricalRegionInfo; -import org.apache.ignite.ml.trees.CategoricalSplitInfo; -import org.apache.ignite.ml.trees.RegionInfo; -import org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer; -import org.apache.ignite.ml.trees.trainers.columnbased.RegionProjection; - -import static org.apache.ignite.ml.trees.trainers.columnbased.vectors.FeatureVectorProcessorUtils.splitByBitSet; - -/** - * Categorical feature vector processor implementation used by {@link ColumnDecisionTreeTrainer}. - */ -public class CategoricalFeatureProcessor - implements FeatureProcessor<CategoricalRegionInfo, CategoricalSplitInfo<CategoricalRegionInfo>> { - /** Count of categories for this feature. */ - private final int catsCnt; - - /** Function for calculating impurity of a given region of points. */ - private final IgniteFunction<DoubleStream, Double> calc; - - /** - * @param calc Function for calculating impurity of a given region of points. - * @param catsCnt Number of categories. 
- */ - public CategoricalFeatureProcessor(IgniteFunction<DoubleStream, Double> calc, int catsCnt) { - this.calc = calc; - this.catsCnt = catsCnt; - } - - /** */ - private SplitInfo<CategoricalRegionInfo> split(BitSet leftCats, int intervalIdx, Map<Integer, Integer> mapping, - Integer[] sampleIndexes, double[] values, double[] labels, double impurity) { - Map<Boolean, List<Integer>> leftRight = Arrays.stream(sampleIndexes). - collect(Collectors.partitioningBy((smpl) -> leftCats.get(mapping.get((int)values[smpl])))); - - List<Integer> left = leftRight.get(true); - int leftSize = left.size(); - double leftImpurity = calc.apply(left.stream().mapToDouble(s -> labels[s])); - - List<Integer> right = leftRight.get(false); - int rightSize = right.size(); - double rightImpurity = calc.apply(right.stream().mapToDouble(s -> labels[s])); - - int totalSize = leftSize + rightSize; - - // Result of this call will be sent back to trainer node, we do not need vectors inside of sent data. - CategoricalSplitInfo<CategoricalRegionInfo> res = new CategoricalSplitInfo<>(intervalIdx, - new CategoricalRegionInfo(leftImpurity, null), // cats can be computed on the last step. - new CategoricalRegionInfo(rightImpurity, null), - leftCats); - - res.setInfoGain(impurity - (double)leftSize / totalSize * leftImpurity - (double)rightSize / totalSize * rightImpurity); - return res; - } - - /** - * Get a stream of subsets given categories count. - * - * @param catsCnt categories count. - * @return Stream of subsets given categories count. - */ - private Stream<BitSet> powerSet(int catsCnt) { - Iterable<BitSet> iterable = () -> new PSI(catsCnt); - return StreamSupport.stream(iterable.spliterator(), false); - } - - /** {@inheritDoc} */ - @Override public SplitInfo findBestSplit(RegionProjection<CategoricalRegionInfo> regionPrj, double[] values, - double[] labels, int regIdx) { - Map<Integer, Integer> mapping = mapping(regionPrj.data().cats()); - - return powerSet(regionPrj.data().cats().length()). 
- map(s -> split(s, regIdx, mapping, regionPrj.sampleIndexes(), values, labels, regionPrj.data().impurity())). - max(Comparator.comparingDouble(SplitInfo::infoGain)). - orElse(null); - } - - /** {@inheritDoc} */ - @Override public RegionProjection<CategoricalRegionInfo> createInitialRegion(Integer[] sampleIndexes, - double[] values, double[] labels) { - BitSet set = new BitSet(); - set.set(0, catsCnt); - - Double impurity = calc.apply(Arrays.stream(labels)); - - return new RegionProjection<>(sampleIndexes, new CategoricalRegionInfo(impurity, set), 0); - } - - /** {@inheritDoc} */ - @Override public SparseBitSet calculateOwnershipBitSet(RegionProjection<CategoricalRegionInfo> regionPrj, - double[] values, - CategoricalSplitInfo<CategoricalRegionInfo> s) { - SparseBitSet res = new SparseBitSet(); - Arrays.stream(regionPrj.sampleIndexes()).forEach(smpl -> res.set(smpl, s.bitSet().get((int)values[smpl]))); - return res; - } - - /** {@inheritDoc} */ - @Override public IgniteBiTuple<RegionProjection, RegionProjection> performSplit(SparseBitSet bs, - RegionProjection<CategoricalRegionInfo> reg, CategoricalRegionInfo leftData, CategoricalRegionInfo rightData) { - return performSplitGeneric(bs, null, reg, leftData, rightData); - } - - /** {@inheritDoc} */ - @Override public IgniteBiTuple<RegionProjection, RegionProjection> performSplitGeneric( - SparseBitSet bs, double[] values, RegionProjection<CategoricalRegionInfo> reg, RegionInfo leftData, - RegionInfo rightData) { - int depth = reg.depth(); - - int lSize = bs.cardinality(); - int rSize = reg.sampleIndexes().length - lSize; - IgniteBiTuple<Integer[], Integer[]> lrSamples = splitByBitSet(lSize, rSize, reg.sampleIndexes(), bs); - BitSet leftCats = calculateCats(lrSamples.get1(), values); - CategoricalRegionInfo lInfo = new CategoricalRegionInfo(leftData.impurity(), leftCats); - - // TODO: IGNITE-5892 Check how it will work with sparse data. 
- BitSet rightCats = calculateCats(lrSamples.get2(), values); - CategoricalRegionInfo rInfo = new CategoricalRegionInfo(rightData.impurity(), rightCats); - - RegionProjection<CategoricalRegionInfo> rPrj = new RegionProjection<>(lrSamples.get2(), rInfo, depth + 1); - RegionProjection<CategoricalRegionInfo> lPrj = new RegionProjection<>(lrSamples.get1(), lInfo, depth + 1); - return new IgniteBiTuple<>(lPrj, rPrj); - } - - /** - * Powerset iterator. Iterates not over the whole powerset, but on half of it. - */ - private static class PSI implements Iterator<BitSet> { - - /** Current subset number. */ - private int i = 1; // We are not interested in {emptyset, set} split and therefore start from 1. - - /** Size of set, subsets of which we iterate over. */ - final int size; - - /** - * @param bitCnt Size of set, subsets of which we iterate over. - */ - PSI(int bitCnt) { - this.size = 1 << (bitCnt - 1); - } - - /** {@inheritDoc} */ - @Override public boolean hasNext() { - return i < size; - } - - /** {@inheritDoc} */ - @Override public BitSet next() { - BitSet res = BitSet.valueOf(new long[] {i}); - i++; - return res; - } - } - - /** */ - private Map<Integer, Integer> mapping(BitSet bs) { - int bn = 0; - Map<Integer, Integer> res = new HashMap<>(); - - int i = 0; - while ((bn = bs.nextSetBit(bn)) != -1) { - res.put(bn, i); - i++; - bn++; - } - - return res; - } - - /** Get set of categories of given samples */ - private BitSet calculateCats(Integer[] sampleIndexes, double[] values) { - BitSet res = new BitSet(); - - for (int smpl : sampleIndexes) - res.set((int)values[smpl]); - - return res; - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/ContinuousFeatureProcessor.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/ContinuousFeatureProcessor.java 
b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/ContinuousFeatureProcessor.java deleted file mode 100644 index 4117993..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/ContinuousFeatureProcessor.java +++ /dev/null @@ -1,111 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trees.trainers.columnbased.vectors; - -import com.zaxxer.sparsebits.SparseBitSet; -import java.util.Arrays; -import java.util.Comparator; -import org.apache.ignite.lang.IgniteBiTuple; -import org.apache.ignite.ml.trees.ContinuousRegionInfo; -import org.apache.ignite.ml.trees.ContinuousSplitCalculator; -import org.apache.ignite.ml.trees.RegionInfo; -import org.apache.ignite.ml.trees.trainers.columnbased.RegionProjection; - -import static org.apache.ignite.ml.trees.trainers.columnbased.vectors.FeatureVectorProcessorUtils.splitByBitSet; - -/** - * Container of projection of samples on continuous feature. - * - * @param <D> Information about regions. Designed to contain information which will make computations of impurity - * optimal. 
- */ -public class ContinuousFeatureProcessor<D extends ContinuousRegionInfo> implements - FeatureProcessor<D, ContinuousSplitInfo<D>> { - /** ContinuousSplitCalculator used for calculating of best split of each region. */ - private final ContinuousSplitCalculator<D> calc; - - /** - * @param splitCalc Calculator used for calculating splits. - */ - public ContinuousFeatureProcessor(ContinuousSplitCalculator<D> splitCalc) { - this.calc = splitCalc; - } - - /** {@inheritDoc} */ - @Override public SplitInfo<D> findBestSplit(RegionProjection<D> ri, double[] values, double[] labels, int regIdx) { - SplitInfo<D> res = calc.splitRegion(ri.sampleIndexes(), values, labels, regIdx, ri.data()); - - if (res == null) - return null; - - double lWeight = (double)res.leftData.getSize() / ri.sampleIndexes().length; - double rWeight = (double)res.rightData.getSize() / ri.sampleIndexes().length; - - double infoGain = ri.data().impurity() - lWeight * res.leftData().impurity() - rWeight * res.rightData().impurity(); - res.setInfoGain(infoGain); - - return res; - } - - /** {@inheritDoc} */ - @Override public RegionProjection<D> createInitialRegion(Integer[] samples, double[] values, double[] labels) { - Arrays.sort(samples, Comparator.comparingDouble(s -> values[s])); - return new RegionProjection<>(samples, calc.calculateRegionInfo(Arrays.stream(labels), samples.length), 0); - } - - /** {@inheritDoc} */ - @Override public SparseBitSet calculateOwnershipBitSet(RegionProjection<D> reg, double[] values, - ContinuousSplitInfo<D> s) { - SparseBitSet res = new SparseBitSet(); - - for (int i = 0; i < s.leftData().getSize(); i++) - res.set(reg.sampleIndexes()[i]); - - return res; - } - - /** {@inheritDoc} */ - @Override public IgniteBiTuple<RegionProjection, RegionProjection> performSplit(SparseBitSet bs, - RegionProjection<D> reg, D leftData, D rightData) { - int lSize = leftData.getSize(); - int rSize = rightData.getSize(); - int depth = reg.depth(); - - IgniteBiTuple<Integer[], Integer[]> 
lrSamples = splitByBitSet(lSize, rSize, reg.sampleIndexes(), bs); - - RegionProjection<D> left = new RegionProjection<>(lrSamples.get1(), leftData, depth + 1); - RegionProjection<D> right = new RegionProjection<>(lrSamples.get2(), rightData, depth + 1); - - return new IgniteBiTuple<>(left, right); - } - - /** {@inheritDoc} */ - @Override public IgniteBiTuple<RegionProjection, RegionProjection> performSplitGeneric(SparseBitSet bs, - double[] labels, RegionProjection<D> reg, RegionInfo leftData, RegionInfo rightData) { - int lSize = bs.cardinality(); - int rSize = reg.sampleIndexes().length - lSize; - int depth = reg.depth(); - - IgniteBiTuple<Integer[], Integer[]> lrSamples = splitByBitSet(lSize, rSize, reg.sampleIndexes(), bs); - - D ld = calc.calculateRegionInfo(Arrays.stream(lrSamples.get1()).mapToDouble(s -> labels[s]), lSize); - D rd = calc.calculateRegionInfo(Arrays.stream(lrSamples.get2()).mapToDouble(s -> labels[s]), rSize); - - return new IgniteBiTuple<>(new RegionProjection<>(lrSamples.get1(), ld, depth + 1), new RegionProjection<>(lrSamples.get2(), rd, depth + 1)); - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/ContinuousSplitInfo.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/ContinuousSplitInfo.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/ContinuousSplitInfo.java deleted file mode 100644 index 8b45cb5..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/ContinuousSplitInfo.java +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trees.trainers.columnbased.vectors; - -import org.apache.ignite.ml.trees.RegionInfo; -import org.apache.ignite.ml.trees.nodes.ContinuousSplitNode; -import org.apache.ignite.ml.trees.nodes.SplitNode; - -/** - * Information about split of continuous region. - * - * @param <D> Class encapsulating information about the region. - */ -public class ContinuousSplitInfo<D extends RegionInfo> extends SplitInfo<D> { - /** - * Threshold used for split. - * Samples with values less or equal than this go to left region, others go to the right region. - */ - private final double threshold; - - /** - * @param regionIdx Index of region being split. - * @param threshold Threshold used for split. Samples with values less or equal than this go to left region, others - * go to the right region. - * @param leftData Information about left subregion. - * @param rightData Information about right subregion. - */ - public ContinuousSplitInfo(int regionIdx, double threshold, D leftData, D rightData) { - super(regionIdx, leftData, rightData); - this.threshold = threshold; - } - - /** {@inheritDoc} */ - @Override public SplitNode createSplitNode(int featureIdx) { - return new ContinuousSplitNode(threshold, featureIdx); - } - - /** - * Threshold used for splits. - * Samples with values less or equal than this go to left region, others go to the right region. 
- */ - public double threshold() { - return threshold; - } - - /** {@inheritDoc} */ - @Override public String toString() { - return "ContinuousSplitInfo [" + - "threshold=" + threshold + - ", infoGain=" + infoGain + - ", regionIdx=" + regionIdx + - ", leftData=" + leftData + - ", rightData=" + rightData + - ']'; - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/FeatureProcessor.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/FeatureProcessor.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/FeatureProcessor.java deleted file mode 100644 index 56508e5..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/FeatureProcessor.java +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.ignite.ml.trees.trainers.columnbased.vectors; - -import com.zaxxer.sparsebits.SparseBitSet; -import org.apache.ignite.lang.IgniteBiTuple; -import org.apache.ignite.ml.trees.RegionInfo; -import org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer; -import org.apache.ignite.ml.trees.trainers.columnbased.RegionProjection; - -/** - * Base interface for feature processors used in {@link ColumnDecisionTreeTrainer} - * - * @param <D> Class representing data of regions resulted from split. - * @param <S> Class representing data of split. - */ -public interface FeatureProcessor<D extends RegionInfo, S extends SplitInfo<D>> { - /** - * Finds best split by this feature among all splits of all regions. - * - * @return best split by this feature among all splits of all regions. - */ - SplitInfo findBestSplit(RegionProjection<D> regionPrj, double[] values, double[] labels, int regIdx); - - /** - * Creates initial region from samples. - * - * @param samples samples. - * @return region. - */ - RegionProjection<D> createInitialRegion(Integer[] samples, double[] values, double[] labels); - - /** - * Calculates the bitset mapping each data point to left (corresponding bit is set) or right subregion. - * - * @param s data used for calculating the split. - * @return Bitset mapping each data point to left (corresponding bit is set) or right subregion. - */ - SparseBitSet calculateOwnershipBitSet(RegionProjection<D> regionPrj, double[] values, S s); - - /** - * Splits given region using bitset which maps data point to left or right subregion. - * This method is present for the vectors of the same type to be able to pass between them information about regions - * and therefore used iff the optimal split is received on feature of the same type. - * - * @param bs Bitset which maps data point to left or right subregion. - * @param leftData Data of the left subregion. - * @param rightData Data of the right subregion. 
- * @return This feature vector. - */ - IgniteBiTuple<RegionProjection, RegionProjection> performSplit(SparseBitSet bs, RegionProjection<D> reg, D leftData, - D rightData); - - /** - * Splits given region using bitset which maps data point to left or right subregion. This method is used iff the - * optimal split is received on feature of different type, therefore information about regions is limited to the - * {@link RegionInfo} class which is base for all classes used to represent region data. - * - * @param bs Bitset which maps data point to left or right subregion. - * @param leftData Data of the left subregion. - * @param rightData Data of the right subregion. - * @return This feature vector. - */ - IgniteBiTuple<RegionProjection, RegionProjection> performSplitGeneric(SparseBitSet bs, double[] values, - RegionProjection<D> reg, RegionInfo leftData, - RegionInfo rightData); -} http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/FeatureVectorProcessorUtils.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/FeatureVectorProcessorUtils.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/FeatureVectorProcessorUtils.java deleted file mode 100644 index 69ff019..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/FeatureVectorProcessorUtils.java +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trees.trainers.columnbased.vectors; - -import com.zaxxer.sparsebits.SparseBitSet; -import org.apache.ignite.lang.IgniteBiTuple; - -/** Utility class for feature vector processors. */ -public class FeatureVectorProcessorUtils { - /** - * Split target array into two (left and right) arrays by bitset. - * - * @param lSize Left array size; - * @param rSize Right array size. - * @param samples Arrays to split size. - * @param bs Bitset specifying split. - * @return BiTuple containing result of split. - */ - public static IgniteBiTuple<Integer[], Integer[]> splitByBitSet(int lSize, int rSize, Integer[] samples, - SparseBitSet bs) { - Integer[] lArr = new Integer[lSize]; - Integer[] rArr = new Integer[rSize]; - - int lc = 0; - int rc = 0; - - for (int i = 0; i < lSize + rSize; i++) { - int si = samples[i]; - - if (bs.get(si)) { - lArr[lc] = si; - lc++; - } - else { - rArr[rc] = si; - rc++; - } - } - - return new IgniteBiTuple<>(lArr, rArr); - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/SampleInfo.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/SampleInfo.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/SampleInfo.java deleted file mode 100644 index 8aa4f79..0000000 --- 
a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/SampleInfo.java +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trees.trainers.columnbased.vectors; - -import java.io.Externalizable; -import java.io.IOException; -import java.io.ObjectInput; -import java.io.ObjectOutput; - -/** - * Information about given sample within given fixed feature. - */ -public class SampleInfo implements Externalizable { - /** Value of projection of this sample on given fixed feature. */ - private double val; - - /** Sample index. */ - private int sampleIdx; - - /** - * @param val Value of projection of this sample on given fixed feature. - * @param sampleIdx Sample index. - */ - public SampleInfo(double val, int sampleIdx) { - this.val = val; - this.sampleIdx = sampleIdx; - } - - /** - * No-op constructor used for serialization/deserialization. - */ - public SampleInfo() { - // No-op. - } - - /** - * Get the value of projection of this sample on given fixed feature. - * - * @return Value of projection of this sample on given fixed feature. - */ - public double val() { - return val; - } - - /** - * Get the sample index. 
- * - * @return Sample index. - */ - public int sampleInd() { - return sampleIdx; - } - - /** {@inheritDoc} */ - @Override public void writeExternal(ObjectOutput out) throws IOException { - out.writeDouble(val); - out.writeInt(sampleIdx); - } - - /** {@inheritDoc} */ - @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { - val = in.readDouble(); - sampleIdx = in.readInt(); - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/SplitInfo.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/SplitInfo.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/SplitInfo.java deleted file mode 100644 index 124e82f..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/SplitInfo.java +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.ignite.ml.trees.trainers.columnbased.vectors; - -import org.apache.ignite.ml.trees.RegionInfo; -import org.apache.ignite.ml.trees.nodes.SplitNode; - -/** - * Class encapsulating information about the split. - * - * @param <D> Class representing information of left and right subregions. - */ -public abstract class SplitInfo<D extends RegionInfo> { - /** Information gain of this split. */ - protected double infoGain; - - /** Index of the region to split. */ - protected final int regionIdx; - - /** Data of left subregion. */ - protected final D leftData; - - /** Data of right subregion. */ - protected final D rightData; - - /** - * Construct the split info. - * - * @param regionIdx Index of the region to split. - * @param leftData Data of left subregion. - * @param rightData Data of right subregion. - */ - public SplitInfo(int regionIdx, D leftData, D rightData) { - this.regionIdx = regionIdx; - this.leftData = leftData; - this.rightData = rightData; - } - - /** - * Index of region to split. - * - * @return Index of region to split. - */ - public int regionIndex() { - return regionIdx; - } - - /** - * Information gain of the split. - * - * @return Information gain of the split. - */ - public double infoGain() { - return infoGain; - } - - /** - * Data of right subregion. - * - * @return Data of right subregion. - */ - public D rightData() { - return rightData; - } - - /** - * Data of left subregion. - * - * @return Data of left subregion. - */ - public D leftData() { - return leftData; - } - - /** - * Create SplitNode from this split info. - * - * @param featureIdx Index of feature by which goes split. - * @return SplitNode from this split info. - */ - public abstract SplitNode createSplitNode(int featureIdx); - - /** - * Set information gain. - * - * @param infoGain Information gain. 
- */ - public void setInfoGain(double infoGain) { - this.infoGain = infoGain; - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/package-info.java deleted file mode 100644 index 0dea204..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/package-info.java +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * <!-- Package description. --> - * Contains feature containers needed by column based decision tree trainers. 
- */ -package org.apache.ignite.ml.trees.trainers.columnbased.vectors; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java index e22a3a5..9900f85 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java @@ -28,7 +28,7 @@ import org.apache.ignite.ml.preprocessing.PreprocessingTestSuite; import org.apache.ignite.ml.regressions.RegressionsTestSuite; import org.apache.ignite.ml.svm.SVMTestSuite; import org.apache.ignite.ml.trainers.group.TrainersGroupTestSuite; -import org.apache.ignite.ml.trees.DecisionTreesTestSuite; +import org.apache.ignite.ml.tree.DecisionTreeTestSuite; import org.junit.runner.RunWith; import org.junit.runners.Suite; @@ -41,7 +41,7 @@ import org.junit.runners.Suite; RegressionsTestSuite.class, SVMTestSuite.class, ClusteringTestSuite.class, - DecisionTreesTestSuite.class, + DecisionTreeTestSuite.class, KNNTestSuite.class, LocalModelsTest.class, MLPTestSuite.class, http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistMLPTestUtil.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistMLPTestUtil.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistMLPTestUtil.java index e624004..d68b355 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistMLPTestUtil.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistMLPTestUtil.java @@ -25,11 +25,10 @@ import java.util.Random; import 
java.util.stream.Stream; import org.apache.ignite.lang.IgniteBiTuple; import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; -import org.apache.ignite.ml.trees.performance.ColumnDecisionTreeTrainerBenchmark; import org.apache.ignite.ml.util.MnistUtils; /** */ -class MnistMLPTestUtil { +public class MnistMLPTestUtil { /** Name of the property specifying path to training set images. */ private static final String PROP_TRAINING_IMAGES = "mnist.training.images"; @@ -62,7 +61,7 @@ class MnistMLPTestUtil { * @return List of MNIST images. * @throws IOException In case of exception. */ - static List<MnistUtils.MnistLabeledImage> loadTrainingSet(int cnt) throws IOException { + public static List<MnistUtils.MnistLabeledImage> loadTrainingSet(int cnt) throws IOException { Properties props = loadMNISTProperties(); return MnistUtils.mnistAsList(props.getProperty(PROP_TRAINING_IMAGES), props.getProperty(PROP_TRAINING_LABELS), new Random(123L), cnt); } @@ -74,7 +73,7 @@ class MnistMLPTestUtil { * @return List of MNIST images. * @throws IOException In case of exception. 
*/ - static List<MnistUtils.MnistLabeledImage> loadTestSet(int cnt) throws IOException { + public static List<MnistUtils.MnistLabeledImage> loadTestSet(int cnt) throws IOException { Properties props = loadMNISTProperties(); return MnistUtils.mnistAsList(props.getProperty(PROP_TEST_IMAGES), props.getProperty(PROP_TEST_LABELS), new Random(123L), cnt); } @@ -83,7 +82,7 @@ class MnistMLPTestUtil { private static Properties loadMNISTProperties() throws IOException { Properties res = new Properties(); - InputStream is = ColumnDecisionTreeTrainerBenchmark.class.getClassLoader().getResourceAsStream("manualrun/trees/columntrees.manualrun.properties"); + InputStream is = MnistMLPTestUtil.class.getClassLoader().getResourceAsStream("manualrun/trees/columntrees.manualrun.properties"); res.load(is); http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerIntegrationTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerIntegrationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerIntegrationTest.java new file mode 100644 index 0000000..94bca3f --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerIntegrationTest.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.tree; + +import java.util.Arrays; +import java.util.Random; +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; +import org.apache.ignite.configuration.CacheConfiguration; +import org.apache.ignite.internal.util.IgniteUtils; +import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder; +import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; + +/** + * Tests for {@link DecisionTreeClassificationTrainer} that require to start the whole Ignite infrastructure. + */ +public class DecisionTreeClassificationTrainerIntegrationTest extends GridCommonAbstractTest { + /** Number of nodes in grid */ + private static final int NODE_COUNT = 3; + + /** Ignite instance. */ + private Ignite ignite; + + /** {@inheritDoc} */ + @Override protected void beforeTestsStarted() throws Exception { + for (int i = 1; i <= NODE_COUNT; i++) + startGrid(i); + } + + /** {@inheritDoc} */ + @Override protected void afterTestsStopped() { + stopAllGrids(); + } + + /** + * {@inheritDoc} + */ + @Override protected void beforeTest() throws Exception { + /* Grid instance. 
*/ + ignite = grid(NODE_COUNT); + ignite.configuration().setPeerClassLoadingEnabled(true); + IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); + } + + /** */ + public void testFit() { + int size = 100; + + CacheConfiguration<Integer, double[]> trainingSetCacheCfg = new CacheConfiguration<>(); + trainingSetCacheCfg.setAffinity(new RendezvousAffinityFunction(false, 10)); + trainingSetCacheCfg.setName("TRAINING_SET"); + + IgniteCache<Integer, double[]> data = ignite.createCache(trainingSetCacheCfg); + + Random rnd = new Random(0); + for (int i = 0; i < size; i++) { + double x = rnd.nextDouble() - 0.5; + data.put(i, new double[]{x, x > 0 ? 1 : 0}); + } + + DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(1, 0); + + DecisionTreeNode tree = trainer.fit( + new CacheBasedDatasetBuilder<>(ignite, data), + (k, v) -> Arrays.copyOf(v, v.length - 1), + (k, v) -> v[v.length - 1] + ); + + assertTrue(tree instanceof DecisionTreeConditionalNode); + + DecisionTreeConditionalNode node = (DecisionTreeConditionalNode) tree; + + assertEquals(0, node.getThreshold(), 1e-3); + + assertTrue(node.getThenNode() instanceof DecisionTreeLeafNode); + assertTrue(node.getElseNode() instanceof DecisionTreeLeafNode); + + DecisionTreeLeafNode thenNode = (DecisionTreeLeafNode) node.getThenNode(); + DecisionTreeLeafNode elseNode = (DecisionTreeLeafNode) node.getElseNode(); + + assertEquals(1, thenNode.getVal(), 1e-10); + assertEquals(0, elseNode.getVal(), 1e-10); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java new file mode 100644 index 0000000..2599bfe --- 
/dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.tree; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; +import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import static junit.framework.TestCase.assertEquals; +import static junit.framework.TestCase.assertTrue; + +/** + * Tests for {@link DecisionTreeClassificationTrainer}. + */ +@RunWith(Parameterized.class) +public class DecisionTreeClassificationTrainerTest { + /** Number of parts to be tested. */ + private static final int[] partsToBeTested = new int[] {1, 2, 3, 4, 5, 7}; + + /** Number of partitions. 
*/ + @Parameterized.Parameter + public int parts; + + @Parameterized.Parameters(name = "Data divided on {0} partitions") + public static Iterable<Integer[]> data() { + List<Integer[]> res = new ArrayList<>(); + for (int part : partsToBeTested) + res.add(new Integer[] {part}); + + return res; + } + + /** */ + @Test + public void testFit() { + int size = 100; + + Map<Integer, double[]> data = new HashMap<>(); + + Random rnd = new Random(0); + for (int i = 0; i < size; i++) { + double x = rnd.nextDouble() - 0.5; + data.put(i, new double[]{x, x > 0 ? 1 : 0}); + } + + DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(1, 0); + + DecisionTreeNode tree = trainer.fit( + new LocalDatasetBuilder<>(data, parts), + (k, v) -> Arrays.copyOf(v, v.length - 1), + (k, v) -> v[v.length - 1] + ); + + assertTrue(tree instanceof DecisionTreeConditionalNode); + + DecisionTreeConditionalNode node = (DecisionTreeConditionalNode) tree; + + assertEquals(0, node.getThreshold(), 1e-3); + + assertTrue(node.getThenNode() instanceof DecisionTreeLeafNode); + assertTrue(node.getElseNode() instanceof DecisionTreeLeafNode); + + DecisionTreeLeafNode thenNode = (DecisionTreeLeafNode) node.getThenNode(); + DecisionTreeLeafNode elseNode = (DecisionTreeLeafNode) node.getElseNode(); + + assertEquals(1, thenNode.getVal(), 1e-10); + assertEquals(0, elseNode.getVal(), 1e-10); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerIntegrationTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerIntegrationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerIntegrationTest.java new file mode 100644 index 0000000..754ff20 --- /dev/null +++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerIntegrationTest.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.tree; + +import java.util.Arrays; +import java.util.Random; +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; +import org.apache.ignite.configuration.CacheConfiguration; +import org.apache.ignite.internal.util.IgniteUtils; +import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder; +import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; + +/** + * Tests for {@link DecisionTreeRegressionTrainer} that require to start the whole Ignite infrastructure. + */ +public class DecisionTreeRegressionTrainerIntegrationTest extends GridCommonAbstractTest { + /** Number of nodes in grid */ + private static final int NODE_COUNT = 3; + + /** Ignite instance. 
*/ + private Ignite ignite; + + /** {@inheritDoc} */ + @Override protected void beforeTestsStarted() throws Exception { + for (int i = 1; i <= NODE_COUNT; i++) + startGrid(i); + } + + /** {@inheritDoc} */ + @Override protected void afterTestsStopped() { + stopAllGrids(); + } + + /** + * {@inheritDoc} + */ + @Override protected void beforeTest() throws Exception { + /* Grid instance. */ + ignite = grid(NODE_COUNT); + ignite.configuration().setPeerClassLoadingEnabled(true); + IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); + } + + /** */ + public void testFit() { + int size = 100; + + CacheConfiguration<Integer, double[]> trainingSetCacheCfg = new CacheConfiguration<>(); + trainingSetCacheCfg.setAffinity(new RendezvousAffinityFunction(false, 10)); + trainingSetCacheCfg.setName("TRAINING_SET"); + + IgniteCache<Integer, double[]> data = ignite.createCache(trainingSetCacheCfg); + + Random rnd = new Random(0); + for (int i = 0; i < size; i++) { + double x = rnd.nextDouble() - 0.5; + data.put(i, new double[]{x, x > 0 ? 
1 : 0}); + } + + DecisionTreeRegressionTrainer trainer = new DecisionTreeRegressionTrainer(1, 0); + + DecisionTreeNode tree = trainer.fit( + new CacheBasedDatasetBuilder<>(ignite, data), + (k, v) -> Arrays.copyOf(v, v.length - 1), + (k, v) -> v[v.length - 1] + ); + + assertTrue(tree instanceof DecisionTreeConditionalNode); + + DecisionTreeConditionalNode node = (DecisionTreeConditionalNode) tree; + + assertEquals(0, node.getThreshold(), 1e-3); + + assertTrue(node.getThenNode() instanceof DecisionTreeLeafNode); + assertTrue(node.getElseNode() instanceof DecisionTreeLeafNode); + + DecisionTreeLeafNode thenNode = (DecisionTreeLeafNode) node.getThenNode(); + DecisionTreeLeafNode elseNode = (DecisionTreeLeafNode) node.getElseNode(); + + assertEquals(1, thenNode.getVal(), 1e-10); + assertEquals(0, elseNode.getVal(), 1e-10); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java new file mode 100644 index 0000000..3bdbf60 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.tree; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; +import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import static junit.framework.TestCase.assertEquals; +import static junit.framework.TestCase.assertTrue; + +/** + * Tests for {@link DecisionTreeRegressionTrainer}. + */ +@RunWith(Parameterized.class) +public class DecisionTreeRegressionTrainerTest { + /** Number of parts to be tested. */ + private static final int[] partsToBeTested = new int[] {1, 2, 3, 4, 5, 7}; + + /** Number of partitions. */ + @Parameterized.Parameter + public int parts; + + @Parameterized.Parameters(name = "Data divided on {0} partitions") + public static Iterable<Integer[]> data() { + List<Integer[]> res = new ArrayList<>(); + for (int part : partsToBeTested) + res.add(new Integer[] {part}); + + return res; + } + + /** */ + @Test + public void testFit() { + int size = 100; + + Map<Integer, double[]> data = new HashMap<>(); + + Random rnd = new Random(0); + for (int i = 0; i < size; i++) { + double x = rnd.nextDouble() - 0.5; + data.put(i, new double[]{x, x > 0 ? 
1 : 0}); + } + + DecisionTreeRegressionTrainer trainer = new DecisionTreeRegressionTrainer(1, 0); + + DecisionTreeNode tree = trainer.fit( + new LocalDatasetBuilder<>(data, parts), + (k, v) -> Arrays.copyOf(v, v.length - 1), + (k, v) -> v[v.length - 1] + ); + + assertTrue(tree instanceof DecisionTreeConditionalNode); + + DecisionTreeConditionalNode node = (DecisionTreeConditionalNode) tree; + + assertEquals(0, node.getThreshold(), 1e-3); + + assertTrue(node.getThenNode() instanceof DecisionTreeLeafNode); + assertTrue(node.getElseNode() instanceof DecisionTreeLeafNode); + + DecisionTreeLeafNode thenNode = (DecisionTreeLeafNode) node.getThenNode(); + DecisionTreeLeafNode elseNode = (DecisionTreeLeafNode) node.getElseNode(); + + assertEquals(1, thenNode.getVal(), 1e-10); + assertEquals(0, elseNode.getVal(), 1e-10); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeTestSuite.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeTestSuite.java new file mode 100644 index 0000000..2cbb486 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeTestSuite.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.tree; + +import org.apache.ignite.ml.tree.data.DecisionTreeDataTest; +import org.apache.ignite.ml.tree.impurity.gini.GiniImpurityMeasureCalculatorTest; +import org.apache.ignite.ml.tree.impurity.gini.GiniImpurityMeasureTest; +import org.apache.ignite.ml.tree.impurity.mse.MSEImpurityMeasureCalculatorTest; +import org.apache.ignite.ml.tree.impurity.mse.MSEImpurityMeasureTest; +import org.apache.ignite.ml.tree.impurity.util.SimpleStepFunctionCompressorTest; +import org.apache.ignite.ml.tree.impurity.util.StepFunctionTest; +import org.junit.runner.RunWith; +import org.junit.runners.Suite; + +/** + * Test suite that runs the tests of the {@link org.apache.ignite.ml.tree} package and its subpackages. 
+ */ +@RunWith(Suite.class) +@Suite.SuiteClasses({ + DecisionTreeClassificationTrainerTest.class, + DecisionTreeRegressionTrainerTest.class, + DecisionTreeClassificationTrainerIntegrationTest.class, + DecisionTreeRegressionTrainerIntegrationTest.class, + DecisionTreeDataTest.class, + GiniImpurityMeasureCalculatorTest.class, + GiniImpurityMeasureTest.class, + MSEImpurityMeasureCalculatorTest.class, + MSEImpurityMeasureTest.class, + StepFunctionTest.class, + SimpleStepFunctionCompressorTest.class +}) +public class DecisionTreeTestSuite { +} http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/test/java/org/apache/ignite/ml/tree/data/DecisionTreeDataTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/data/DecisionTreeDataTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/data/DecisionTreeDataTest.java new file mode 100644 index 0000000..0c89d4e --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/data/DecisionTreeDataTest.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.tree.data; + +import org.junit.Test; + +import static org.junit.Assert.assertArrayEquals; + +/** + * Tests for {@link DecisionTreeData}. + */ +public class DecisionTreeDataTest { + /** Checks that {@code filter} keeps only the rows whose first feature is greater than 2, together with their labels. */ + @Test + public void testFilter() { + double[][] features = new double[][]{{0}, {1}, {2}, {3}, {4}, {5}}; + double[] labels = new double[]{0, 1, 2, 3, 4, 5}; + + DecisionTreeData data = new DecisionTreeData(features, labels); + DecisionTreeData filteredData = data.filter(obj -> obj[0] > 2); + + assertArrayEquals(new double[][]{{3}, {4}, {5}}, filteredData.getFeatures()); + assertArrayEquals(new double[]{3, 4, 5}, filteredData.getLabels(), 1e-10); + } + + /** Checks that {@code sort} reorders the feature and label arrays passed to the constructor in place, ascending by the given column. */ + @Test + public void testSort() { + double[][] features = new double[][]{{4, 1}, {3, 3}, {2, 0}, {1, 4}, {0, 2}}; + double[] labels = new double[]{0, 1, 2, 3, 4}; + + DecisionTreeData data = new DecisionTreeData(features, labels); + + data.sort(0); + + assertArrayEquals(new double[][]{{0, 2}, {1, 4}, {2, 0}, {3, 3}, {4, 1}}, features); + assertArrayEquals(new double[]{4, 3, 2, 1, 0}, labels, 1e-10); + + data.sort(1); + + assertArrayEquals(new double[][]{{2, 0}, {4, 1}, {0, 2}, {3, 3}, {1, 4}}, features); + assertArrayEquals(new double[]{2, 0, 4, 1, 3}, labels, 1e-10); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculatorTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculatorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculatorTest.java new file mode 100644 index 0000000..afd81e8 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculatorTest.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * 
contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.tree.impurity.gini; + +import java.util.HashMap; +import java.util.Map; +import org.apache.ignite.ml.tree.data.DecisionTreeData; +import org.apache.ignite.ml.tree.impurity.util.StepFunction; +import org.junit.Test; + +import static junit.framework.TestCase.assertEquals; +import static org.junit.Assert.assertArrayEquals; + +/** + * Tests for {@link GiniImpurityMeasureCalculator}. + */ +public class GiniImpurityMeasureCalculatorTest { + /** Checks the step functions (X points and impurity at each point) computed for two feature columns with distinct values. */ + @Test + public void testCalculate() { + double[][] data = new double[][]{{0, 1}, {1, 0}, {2, 2}, {3, 3}}; + double[] labels = new double[]{0, 1, 1, 1}; + + Map<Double, Integer> encoder = new HashMap<>(); + encoder.put(0.0, 0); + encoder.put(1.0, 1); + GiniImpurityMeasureCalculator calculator = new GiniImpurityMeasureCalculator(encoder); + + StepFunction<GiniImpurityMeasure>[] impurity = calculator.calculate(new DecisionTreeData(data, labels)); + + assertEquals(2, impurity.length); + + // Check Gini calculated for the first column. 
+ assertArrayEquals(new double[]{Double.NEGATIVE_INFINITY, 0, 1, 2, 3}, impurity[0].getX(), 1e-10); + assertEquals(-2.500, impurity[0].getY()[0].impurity(), 1e-3); + assertEquals(-4.000, impurity[0].getY()[1].impurity(),1e-3); + assertEquals(-3.000, impurity[0].getY()[2].impurity(),1e-3); + assertEquals(-2.666, impurity[0].getY()[3].impurity(),1e-3); + assertEquals(-2.500, impurity[0].getY()[4].impurity(),1e-3); + + // Check Gini calculated for the second column. + assertArrayEquals(new double[]{Double.NEGATIVE_INFINITY, 0, 1, 2, 3}, impurity[1].getX(), 1e-10); + assertEquals(-2.500, impurity[1].getY()[0].impurity(),1e-3); + assertEquals(-2.666, impurity[1].getY()[1].impurity(),1e-3); + assertEquals(-3.000, impurity[1].getY()[2].impurity(),1e-3); + assertEquals(-2.666, impurity[1].getY()[3].impurity(),1e-3); + assertEquals(-2.500, impurity[1].getY()[4].impurity(),1e-3); + } + + /** Checks that a feature value occurring twice (2 in this data) yields a single X point in the resulting step function. */ + @Test + public void testCalculateWithRepeatedData() { + double[][] data = new double[][]{{0}, {1}, {2}, {2}, {3}}; + double[] labels = new double[]{0, 1, 1, 1, 1}; + + Map<Double, Integer> encoder = new HashMap<>(); + encoder.put(0.0, 0); + encoder.put(1.0, 1); + GiniImpurityMeasureCalculator calculator = new GiniImpurityMeasureCalculator(encoder); + + StepFunction<GiniImpurityMeasure>[] impurity = calculator.calculate(new DecisionTreeData(data, labels)); + + assertEquals(1, impurity.length); + + // Check Gini calculated for the first column. 
+ assertArrayEquals(new double[]{Double.NEGATIVE_INFINITY, 0, 1, 2, 3}, impurity[0].getX(), 1e-10); + assertEquals(-3.400, impurity[0].getY()[0].impurity(), 1e-3); + assertEquals(-5.000, impurity[0].getY()[1].impurity(),1e-3); + assertEquals(-4.000, impurity[0].getY()[2].impurity(),1e-3); + assertEquals(-3.500, impurity[0].getY()[3].impurity(),1e-3); + assertEquals(-3.400, impurity[0].getY()[4].impurity(),1e-3); + } + + /** Checks that {@code getLabelCode} returns the code registered for each label in the encoder map given to the constructor. */ + @Test + public void testGetLabelCode() { + Map<Double, Integer> encoder = new HashMap<>(); + encoder.put(0.0, 0); + encoder.put(1.0, 1); + encoder.put(2.0, 2); + + GiniImpurityMeasureCalculator calculator = new GiniImpurityMeasureCalculator(encoder); + + assertEquals(0, calculator.getLabelCode(0.0)); + assertEquals(1, calculator.getLabelCode(1.0)); + assertEquals(2, calculator.getLabelCode(2.0)); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureTest.java new file mode 100644 index 0000000..35c456a --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureTest.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.tree.impurity.gini; + +import java.util.Random; +import org.junit.Test; + +import static junit.framework.TestCase.assertEquals; + +/** + * Tests for {@link GiniImpurityMeasure}. + */ +public class GiniImpurityMeasureTest { + /** Expects an impurity of 0 when every class count on both sides is zero. */ + @Test + public void testImpurityOnEmptyData() { + long[] left = new long[]{0, 0, 0}; + long[] right = new long[]{0, 0, 0}; + + GiniImpurityMeasure impurity = new GiniImpurityMeasure(left, right); + + assertEquals(0.0, impurity.impurity(), 1e-10); + } + + /** Expects an impurity of -3 when three samples of a single class sit on the left side only. */ + @Test + public void testImpurityLeftPart() { + long[] left = new long[]{3, 0, 0}; + long[] right = new long[]{0, 0, 0}; + + GiniImpurityMeasure impurity = new GiniImpurityMeasure(left, right); + + assertEquals(-3, impurity.impurity(), 1e-10); + } + + /** Expects an impurity of -3 when three samples of a single class sit on the right side only. */ + @Test + public void testImpurityRightPart() { + long[] left = new long[]{0, 0, 0}; + long[] right = new long[]{3, 0, 0}; + + GiniImpurityMeasure impurity = new GiniImpurityMeasure(left, right); + + assertEquals(-3, impurity.impurity(), 1e-10); + } + + /** Expects an impurity of -6 when each side holds three samples of a single, different class. */ + @Test + public void testImpurityLeftAndRightPart() { + long[] left = new long[]{3, 0, 0}; + long[] right = new long[]{0, 3, 0}; + + GiniImpurityMeasure impurity = new GiniImpurityMeasure(left, right); + + assertEquals(-6, impurity.impurity(), 1e-10); + } + + /** Checks that {@code add} sums the left and right class counts element-wise. */ + @Test + public void testAdd() { + Random rnd = new Random(0); + + GiniImpurityMeasure a = new GiniImpurityMeasure( + new long[]{randCnt(rnd), randCnt(rnd), randCnt(rnd)}, + new long[]{randCnt(rnd), randCnt(rnd), randCnt(rnd)} + ); + + + 
GiniImpurityMeasure b = new GiniImpurityMeasure( + new long[]{randCnt(rnd), randCnt(rnd), randCnt(rnd)}, + new long[]{randCnt(rnd), randCnt(rnd), randCnt(rnd)} + ); + + GiniImpurityMeasure c = a.add(b); + + assertEquals(a.getLeft()[0] + b.getLeft()[0], c.getLeft()[0]); + assertEquals(a.getLeft()[1] + b.getLeft()[1], c.getLeft()[1]); + assertEquals(a.getLeft()[2] + b.getLeft()[2], c.getLeft()[2]); + + assertEquals(a.getRight()[0] + b.getRight()[0], c.getRight()[0]); + assertEquals(a.getRight()[1] + b.getRight()[1], c.getRight()[1]); + assertEquals(a.getRight()[2] + b.getRight()[2], c.getRight()[2]); + } + + /** Checks that {@code subtract} subtracts the left and right class counts element-wise. */ + @Test + public void testSubtract() { + Random rnd = new Random(0); + + GiniImpurityMeasure a = new GiniImpurityMeasure( + new long[]{randCnt(rnd), randCnt(rnd), randCnt(rnd)}, + new long[]{randCnt(rnd), randCnt(rnd), randCnt(rnd)} + ); + + + GiniImpurityMeasure b = new GiniImpurityMeasure( + new long[]{randCnt(rnd), randCnt(rnd), randCnt(rnd)}, + new long[]{randCnt(rnd), randCnt(rnd), randCnt(rnd)} + ); + + GiniImpurityMeasure c = a.subtract(b); + + assertEquals(a.getLeft()[0] - b.getLeft()[0], c.getLeft()[0]); + assertEquals(a.getLeft()[1] - b.getLeft()[1], c.getLeft()[1]); + assertEquals(a.getLeft()[2] - b.getLeft()[2], c.getLeft()[2]); + + assertEquals(a.getRight()[0] - b.getRight()[0], c.getRight()[0]); + assertEquals(a.getRight()[1] - b.getRight()[1], c.getRight()[1]); + assertEquals(a.getRight()[2] - b.getRight()[2], c.getRight()[2]); + } + + /** Generates a random non-negative count; a bounded draw is used because Math.abs(Integer.MIN_VALUE) is negative. 
*/ + private long randCnt(Random rnd) { + return rnd.nextInt(Integer.MAX_VALUE); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculatorTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculatorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculatorTest.java new file mode 100644 index 0000000..510c18f --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculatorTest.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.tree.impurity.mse; + +import org.apache.ignite.ml.tree.data.DecisionTreeData; +import org.apache.ignite.ml.tree.impurity.util.StepFunction; +import org.junit.Test; + +import static junit.framework.TestCase.assertEquals; +import static org.junit.Assert.assertArrayEquals; + +/** + * Tests for {@link MSEImpurityMeasureCalculator}. 
+ */ +public class MSEImpurityMeasureCalculatorTest { + /** Checks the step functions (X points and impurity at each point) computed for two feature columns of a four-sample dataset. */ + @Test + public void testCalculate() { + double[][] data = new double[][]{{0, 2}, {1, 1}, {2, 0}, {3, 3}}; + double[] labels = new double[]{1, 2, 2, 1}; + + MSEImpurityMeasureCalculator calculator = new MSEImpurityMeasureCalculator(); + + StepFunction<MSEImpurityMeasure>[] impurity = calculator.calculate(new DecisionTreeData(data, labels)); + + assertEquals(2, impurity.length); + + // Test MSE calculated for the first column. + assertArrayEquals(new double[]{Double.NEGATIVE_INFINITY, 0, 1, 2, 3}, impurity[0].getX(), 1e-10); + assertEquals(1.000, impurity[0].getY()[0].impurity(), 1e-3); + assertEquals(0.666, impurity[0].getY()[1].impurity(),1e-3); + assertEquals(1.000, impurity[0].getY()[2].impurity(),1e-3); + assertEquals(0.666, impurity[0].getY()[3].impurity(),1e-3); + assertEquals(1.000, impurity[0].getY()[4].impurity(),1e-3); + + // Test MSE calculated for the second column. + assertArrayEquals(new double[]{Double.NEGATIVE_INFINITY, 0, 1, 2, 3}, impurity[1].getX(), 1e-10); + assertEquals(1.000, impurity[1].getY()[0].impurity(),1e-3); + assertEquals(0.666, impurity[1].getY()[1].impurity(),1e-3); + assertEquals(0.000, impurity[1].getY()[2].impurity(),1e-3); + assertEquals(0.666, impurity[1].getY()[3].impurity(),1e-3); + assertEquals(1.000, impurity[1].getY()[4].impurity(),1e-3); + } +}