http://git-wip-us.apache.org/repos/asf/ignite/blob/16a7c980/modules/ml/src/main/java/org/apache/ignite/ml/composition/BaggingModelTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/BaggingModelTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/BaggingModelTrainer.java new file mode 100644 index 0000000..f721d53 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/BaggingModelTrainer.java @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.composition; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import java.util.stream.IntStream; +import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.composition.predictionsaggregator.PredictionsAggregator; +import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.functions.IgniteFunction; +import org.apache.ignite.ml.selection.split.mapper.SHA256UniformMapper; +import org.apache.ignite.ml.trainers.DatasetTrainer; +import org.apache.ignite.ml.util.Utils; +import org.jetbrains.annotations.NotNull; + +/** + * Abstract trainer implementing bagging logic. + */ +public abstract class BaggingModelTrainer implements DatasetTrainer<ModelsComposition, Double> { + /** + * Predictions aggregator. + */ + private final PredictionsAggregator predictionsAggregator; + /** + * Number of features to draw from original features vector to train each model. + */ + private final int maximumFeaturesCntPerMdl; + /** + * Ensemble size. + */ + private final int ensembleSize; + /** + * Size of sample part in percent to train one model. + */ + private final double samplePartSizePerMdl; + /** + * Feature vector size. + */ + private final int featureVectorSize; + /** + * Learning thread pool. + */ + private final ExecutorService threadPool; + + /** + * Constructs new instance of BaggingModelTrainer. + * + * @param predictionsAggregator Predictions aggregator. + * @param featureVectorSize Feature vector size. + * @param maximumFeaturesCntPerMdl Number of features to draw from original features vector to train each model. + * @param ensembleSize Ensemble size. + * @param samplePartSizePerMdl Size of sample part in percent to train one model. 
+ */ + public BaggingModelTrainer(PredictionsAggregator predictionsAggregator, + int featureVectorSize, + int maximumFeaturesCntPerMdl, + int ensembleSize, + double samplePartSizePerMdl) { + + this(predictionsAggregator, featureVectorSize, maximumFeaturesCntPerMdl, ensembleSize, + samplePartSizePerMdl, null); + } + + /** + * Constructs new instance of BaggingModelTrainer. + * + * @param predictionsAggregator Predictions aggregator. + * @param featureVectorSize Feature vector size. + * @param maximumFeaturesCntPerMdl Number of features to draw from original features vector to train each model. + * @param ensembleSize Ensemble size. + * @param samplePartSizePerMdl Size of sample part in percent to train one model. + * @param threadPool Learning thread pool. + */ + public BaggingModelTrainer(PredictionsAggregator predictionsAggregator, + int featureVectorSize, + int maximumFeaturesCntPerMdl, + int ensembleSize, + double samplePartSizePerMdl, + ExecutorService threadPool) { + + this.predictionsAggregator = predictionsAggregator; + this.maximumFeaturesCntPerMdl = maximumFeaturesCntPerMdl; + this.ensembleSize = ensembleSize; + this.samplePartSizePerMdl = samplePartSizePerMdl; + this.featureVectorSize = featureVectorSize; + this.threadPool = threadPool; + } + + /** {@inheritDoc} */ + @Override public <K, V> ModelsComposition fit(DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, double[]> featureExtractor, + IgniteBiFunction<K, V, Double> lbExtractor) { + + List<ModelsComposition.ModelOnFeaturesSubspace> learnedModels = new ArrayList<>(); + List<Future<ModelsComposition.ModelOnFeaturesSubspace>> futures = new ArrayList<>(); + + for (int i = 0; i < ensembleSize; i++) { + if (threadPool == null) + learnedModels.add(learnModel(datasetBuilder, featureExtractor, lbExtractor)); + else { + Future<ModelsComposition.ModelOnFeaturesSubspace> fut = threadPool.submit(() -> { + return learnModel(datasetBuilder, featureExtractor, lbExtractor); + }); + + futures.add(fut); + } 
+ } + + if (threadPool != null) { + for (Future<ModelsComposition.ModelOnFeaturesSubspace> future : futures) { + try { + learnedModels.add(future.get()); + } + catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + } + } + + return new ModelsComposition(learnedModels, predictionsAggregator); + } + + /** + * Trains one model on part of sample and features subspace. + * + * @param datasetBuilder Dataset builder. + * @param featureExtractor Feature extractor. + * @param lbExtractor Label extractor. + */ + @NotNull private <K, V> ModelsComposition.ModelOnFeaturesSubspace learnModel( + DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, double[]> featureExtractor, + IgniteBiFunction<K, V, Double> lbExtractor) { + + Random rnd = new Random(); + SHA256UniformMapper<K, V> sampleFilter = new SHA256UniformMapper<>(rnd); + long featureExtractorSeed = rnd.nextLong(); + Map<Integer, Integer> featuresMapping = createFeaturesMapping(featureExtractorSeed, featureVectorSize); + + //TODO: IGNITE-8867 Need to implement bootstrapping algorithm + Model<double[], Double> mdl = buildDatasetTrainerForModel().fit( + datasetBuilder.withFilter((features, answer) -> sampleFilter.map(features, answer) < samplePartSizePerMdl), + wrapFeatureExtractor(featureExtractor, featuresMapping), + lbExtractor); + + return new ModelsComposition.ModelOnFeaturesSubspace(featuresMapping, mdl); + } + + /** + * Constructs mapping from original feature vector to subspace. + * + * @param seed Seed. + * @param featuresVectorSize Features vector size. 
+ */ + private Map<Integer, Integer> createFeaturesMapping(long seed, int featuresVectorSize) { + int[] featureIdxs = Utils.selectKDistinct(featuresVectorSize, maximumFeaturesCntPerMdl, new Random(seed)); + Map<Integer, Integer> locFeaturesMapping = new HashMap<>(); + + IntStream.range(0, maximumFeaturesCntPerMdl) + .forEach(localId -> locFeaturesMapping.put(localId, featureIdxs[localId])); + + return locFeaturesMapping; + } + + /** + * Creates trainer specific to ensemble. + */ + protected abstract DatasetTrainer<? extends Model<double[], Double>, Double> buildDatasetTrainerForModel(); + + /** + * Wraps the original feature extractor with features subspace mapping applying. + * + * @param featureExtractor Feature extractor. + * @param featureMapping Feature mapping. + */ + private <K, V> IgniteBiFunction<K, V, double[]> wrapFeatureExtractor( + IgniteBiFunction<K, V, double[]> featureExtractor, + Map<Integer, Integer> featureMapping) { + + return featureExtractor.andThen((IgniteFunction<double[], double[]>)featureValues -> { + double[] newFeaturesValues = new double[featureMapping.size()]; + featureMapping.forEach((localId, featureValueId) -> newFeaturesValues[localId] = featureValues[featureValueId]); + return newFeaturesValues; + }); + } +}
http://git-wip-us.apache.org/repos/asf/ignite/blob/16a7c980/modules/ml/src/main/java/org/apache/ignite/ml/composition/ModelsComposition.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/ModelsComposition.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/ModelsComposition.java new file mode 100644 index 0000000..1de82e3 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/ModelsComposition.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.composition; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.composition.predictionsaggregator.PredictionsAggregator; + +/** + * Model consisting of several models and prediction aggregation strategy. + */ +public class ModelsComposition implements Model<double[], Double> { + /** + * Predictions aggregator. + */ + private final PredictionsAggregator predictionsAggregator; + /** + * Models.
+ */ + private final List<ModelOnFeaturesSubspace> models; + + /** + * Constructs a new instance of composition of models. + * + * @param models Basic models (the list is copied, later changes to it do not affect this composition). + * @param predictionsAggregator Predictions aggregator. + */ + public ModelsComposition(List<ModelOnFeaturesSubspace> models, PredictionsAggregator predictionsAggregator) { + this.predictionsAggregator = predictionsAggregator; + this.models = Collections.unmodifiableList(new ArrayList<>(models)); + } + + /** + * Applies the contained models to the features and aggregates their predictions into one prediction. + * + * @param features Features vector. + * @return Estimation. + */ + @Override public Double apply(double[] features) { + double[] predictions = new double[models.size()]; + + for (int i = 0; i < models.size(); i++) + predictions[i] = models.get(i).apply(features); + + return predictionsAggregator.apply(predictions); + } + + /** + * Returns predictions aggregator. + */ + public PredictionsAggregator getPredictionsAggregator() { + return predictionsAggregator; + } + + /** + * Returns the contained models as an unmodifiable list. + */ + public List<ModelOnFeaturesSubspace> getModels() { + return models; + } + + /** + * Model trained on a features subspace with mapping from original features space to subspace. + */ + public static class ModelOnFeaturesSubspace implements Model<double[], Double> { + /** + * Features mapping to subspace. + */ + private final Map<Integer, Integer> featuresMapping; + /** + * Trained model of features subspace. + */ + private final Model<double[], Double> model; + + /** + * Constructs new instance of ModelOnFeaturesSubspace. + * + * @param featuresMapping Features mapping to subspace (the map is copied). + * @param mdl Learned model. + */ + ModelOnFeaturesSubspace(Map<Integer, Integer> featuresMapping, Model<double[], Double> mdl) { + this.featuresMapping = Collections.unmodifiableMap(new HashMap<>(featuresMapping)); + this.model = mdl; + } + + /** + * Projects the features vector onto the subspace according to the mapping and applies the model to it. + * + * @param features Features vector.
+ * @return Estimation. + */ + @Override public Double apply(double[] features) { + double[] newFeatures = new double[featuresMapping.size()]; + featuresMapping.forEach((localId, featureVectorId) -> newFeatures[localId] = features[featureVectorId]); + return model.apply(newFeatures); + } + + /** + * Returns features mapping. + */ + public Map<Integer, Integer> getFeaturesMapping() { + return featuresMapping; + } + + /** + * Returns model. + */ + public Model<double[], Double> getModel() { + return model; + } + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/16a7c980/modules/ml/src/main/java/org/apache/ignite/ml/composition/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/package-info.java new file mode 100644 index 0000000..8da3668 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Contains classes for ensemble of models implementation.
+ */ +package org.apache.ignite.ml.composition; http://git-wip-us.apache.org/repos/asf/ignite/blob/16a7c980/modules/ml/src/main/java/org/apache/ignite/ml/composition/predictionsaggregator/MeanValuePredictionsAggregator.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/predictionsaggregator/MeanValuePredictionsAggregator.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/predictionsaggregator/MeanValuePredictionsAggregator.java new file mode 100644 index 0000000..01e693d --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/predictionsaggregator/MeanValuePredictionsAggregator.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.composition.predictionsaggregator; + +import java.util.Arrays; +import org.apache.ignite.internal.util.typedef.internal.A; + +/** + * Predictions aggregator returning the mean value of predictions.
+ */ +public class MeanValuePredictionsAggregator implements PredictionsAggregator { + /** {@inheritDoc} */ + @Override public Double apply(double[] estimations) { + A.notEmpty(estimations, "estimations vector"); + // DoubleStream.average() may use compensated summation (OpenJDK does), which is more accurate than a plain reduce. + return Arrays.stream(estimations).average().getAsDouble(); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/16a7c980/modules/ml/src/main/java/org/apache/ignite/ml/composition/predictionsaggregator/OnMajorityPredictionsAggregator.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/predictionsaggregator/OnMajorityPredictionsAggregator.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/predictionsaggregator/OnMajorityPredictionsAggregator.java new file mode 100644 index 0000000..cd84a7e --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/predictionsaggregator/OnMajorityPredictionsAggregator.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.ignite.ml.composition.predictionsaggregator; + +import java.util.Comparator; +import java.util.HashMap; +import java.util.Map; +import org.apache.ignite.internal.util.typedef.internal.A; + +/** + * Predictions aggregator returning the most frequent prediction. + * Ties are resolved in favour of the smallest predicted value. + */ +public class OnMajorityPredictionsAggregator implements PredictionsAggregator { + /** {@inheritDoc} */ + @Override public Double apply(double[] estimations) { + A.notEmpty(estimations, "estimations vector"); + + Map<Double, Integer> cntrsByCls = new HashMap<>(); + + for (double predictedVal : estimations) + cntrsByCls.merge(predictedVal, 1, Integer::sum); + + // Break ties explicitly, so that the result does not depend on the HashMap iteration order. + Comparator<Map.Entry<Double, Integer>> byCntThenSmallestCls = Map.Entry.<Double, Integer>comparingByValue() + .thenComparing(Map.Entry.<Double, Integer>comparingByKey(Comparator.reverseOrder())); + + return cntrsByCls.entrySet().stream() + .max(byCntThenSmallestCls) + .get().getKey(); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/16a7c980/modules/ml/src/main/java/org/apache/ignite/ml/composition/predictionsaggregator/PredictionsAggregator.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/predictionsaggregator/PredictionsAggregator.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/predictionsaggregator/PredictionsAggregator.java new file mode 100644 index 0000000..86b1e96 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/predictionsaggregator/PredictionsAggregator.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.composition.predictionsaggregator; + +import org.apache.ignite.ml.math.functions.IgniteFunction; + +/** + * Strategy combining the predictions of the models of an ensemble (one value per model) into a single prediction. + */ +public interface PredictionsAggregator extends IgniteFunction<double[], Double> { +} http://git-wip-us.apache.org/repos/asf/ignite/blob/16a7c980/modules/ml/src/main/java/org/apache/ignite/ml/composition/predictionsaggregator/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/predictionsaggregator/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/predictionsaggregator/package-info.java new file mode 100644 index 0000000..a43aa44 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/predictionsaggregator/package-info.java @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Contains classes for several predictions aggregation strategies + * working with predictions vector from models ensemble. + */ +package org.apache.ignite.ml.composition.predictionsaggregator; http://git-wip-us.apache.org/repos/asf/ignite/blob/16a7c980/modules/ml/src/main/java/org/apache/ignite/ml/dataset/DatasetBuilder.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/DatasetBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/DatasetBuilder.java index a6757ff..19bdde9 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/DatasetBuilder.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/DatasetBuilder.java @@ -18,6 +18,7 @@ package org.apache.ignite.ml.dataset; import java.io.Serializable; +import org.apache.ignite.lang.IgniteBiPredicate; import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder; import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; @@ -46,4 +47,12 @@ public interface DatasetBuilder<K, V> { */ public <C extends Serializable, D extends AutoCloseable> Dataset<C, D> build( PartitionContextBuilder<K, V, C> partCtxBuilder, PartitionDataBuilder<K, V, C, D> partDataBuilder); + + /** + * Returns new instance of DatasetBuilder using conjunction of internal filter and {@code filterToAdd}. + * + * @param filterToAdd Additional filter. + * @return New dataset builder that accepts only the entries accepted by both filters.
+ */ + public DatasetBuilder<K, V> withFilter(IgniteBiPredicate<K, V> filterToAdd); } http://git-wip-us.apache.org/repos/asf/ignite/blob/16a7c980/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetBuilder.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetBuilder.java index b66c8aa..335ce63 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetBuilder.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetBuilder.java @@ -110,4 +110,14 @@ public class CacheBasedDatasetBuilder<K, V> implements DatasetBuilder<K, V> { return new CacheBasedDataset<>(ignite, upstreamCache, filter, datasetCache, partDataBuilder, datasetId); } + + /** {@inheritDoc} */ + @Override public DatasetBuilder<K, V> withFilter(IgniteBiPredicate<K, V> filterToAdd) { + // Copy the field into a local variable: referencing 'filter' directly makes the lambda capture 'this' + // (and with it the Ignite instance and the upstream cache), while the filter is presumably serialized. + IgniteBiPredicate<K, V> curFilter = filter; + + return new CacheBasedDatasetBuilder<>(ignite, upstreamCache, + (e1, e2) -> curFilter.apply(e1, e2) && filterToAdd.apply(e1, e2)); + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/16a7c980/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilder.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilder.java index a4f275d..2586759 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilder.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilder.java @@ -113,6 +113,15 @@ public class LocalDatasetBuilder<K, V> implements DatasetBuilder<K, V> { return new LocalDataset<>(ctxList, dataList); } + /**
{@inheritDoc} */ + @Override public DatasetBuilder<K, V> withFilter(IgniteBiPredicate<K, V> filterToAdd) { + // Copy the field into a local variable so that the composed lambda captures the two predicates only, not 'this'. + IgniteBiPredicate<K, V> curFilter = filter; + + return new LocalDatasetBuilder<>(upstreamMap, + (e1, e2) -> curFilter.apply(e1, e2) && filterToAdd.apply(e1, e2), partitions); + } + /** * Utils class that wraps iterator so that it produces only specified number of entries and allows to transform * entries from one type to another. http://git-wip-us.apache.org/repos/asf/ignite/blob/16a7c980/modules/ml/src/main/java/org/apache/ignite/ml/selection/split/mapper/SHA256UniformMapper.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/split/mapper/SHA256UniformMapper.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/split/mapper/SHA256UniformMapper.java index b0475ca..324d108 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/split/mapper/SHA256UniformMapper.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/split/mapper/SHA256UniformMapper.java @@ -60,7 +60,7 @@ public class SHA256UniformMapper<K, V> implements UniformMapper<K,V> { /** {@inheritDoc} */ @Override public double map(K key, V val) { - int h = key.hashCode(); + int h = Math.abs(key.hashCode()); String str = String.valueOf(key.hashCode()); byte[] hash = getDigest().digest(str.getBytes(StandardCharsets.UTF_8)); http://git-wip-us.apache.org/repos/asf/ignite/blob/16a7c980/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java index c0b88fc..4d95ff3 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java @@ -38,7 +38,7 @@ import org.apache.ignite.ml.tree.leaf.DecisionTreeLeafBuilder; * *
@param <T> Type of impurity measure. */ -abstract class DecisionTree<T extends ImpurityMeasure<T>> implements DatasetTrainer<DecisionTreeNode, Double> { +public abstract class DecisionTree<T extends ImpurityMeasure<T>> implements DatasetTrainer<DecisionTreeNode, Double> { /** Max tree deep. */ private final int maxDeep; @@ -66,6 +66,7 @@ abstract class DecisionTree<T extends ImpurityMeasure<T>> implements DatasetTrai this.decisionTreeLeafBuilder = decisionTreeLeafBuilder; } + /** {@inheritDoc} */ @Override public <K, V> DecisionTreeNode fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { http://git-wip-us.apache.org/repos/asf/ignite/blob/16a7c980/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainer.java new file mode 100644 index 0000000..bb99515 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainer.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.tree.randomforest; + +import java.util.concurrent.ExecutorService; +import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator; +import org.apache.ignite.ml.composition.predictionsaggregator.PredictionsAggregator; +import org.apache.ignite.ml.trainers.DatasetTrainer; +import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer; +import org.apache.ignite.ml.tree.DecisionTreeNode; + +/** + * Random forest classifier trainer. + */ +public class RandomForestClassifierTrainer extends RandomForestTrainer { + /** + * Constructs new instance of RandomForestClassifierTrainer. + * + * @param predictionsAggregator Predictions aggregator. + * @param featureVectorSize Feature vector size. + * @param maximumFeaturesCntPerMdl Number of features to draw from original features vector to train each model. + * @param ensembleSize Ensemble size. + * @param samplePartSizePerMdl Size of sample part in percent to train one model. + * @param maxDeep Max decision tree deep. + * @param minImpurityDecrease Min impurity decrease. + * @param threadPool Learning thread pool. 
+ */ + public RandomForestClassifierTrainer(PredictionsAggregator predictionsAggregator, + int featureVectorSize, + int maximumFeaturesCntPerMdl, + int ensembleSize, + double samplePartSizePerMdl, + int maxDeep, + double minImpurityDecrease, + ExecutorService threadPool) { + + super(predictionsAggregator, featureVectorSize, maximumFeaturesCntPerMdl, + ensembleSize, samplePartSizePerMdl, maxDeep, minImpurityDecrease, threadPool); + } + + /** + * Constructs new instance of RandomForestClassifierTrainer. + * + * @param featureVectorSize Feature vector size. + * @param maximumFeaturesCntPerMdl Number of features to draw from original features vector to train each model. + * @param ensembleSize Ensemble size. + * @param samplePartSizePerMdl Size of sample part in percent to train one model. + * @param maxDeep Max decision tree deep. + * @param minImpurityDecrease Min impurity decrease. + * @param threadPool Learning thread pool. + */ + public RandomForestClassifierTrainer(int featureVectorSize, + int maximumFeaturesCntPerMdl, + int ensembleSize, + double samplePartSizePerMdl, + int maxDeep, double minImpurityDecrease, + ExecutorService threadPool) { + + this(new OnMajorityPredictionsAggregator(), featureVectorSize, maximumFeaturesCntPerMdl, + ensembleSize, samplePartSizePerMdl, maxDeep, minImpurityDecrease, threadPool); + } + + /** + * Constructs new instance of RandomForestClassifierTrainer. + * + * @param featureVectorSize Feature vector size. + * @param maximumFeaturesCntPerMdl Number of features to draw from original features vector to train each model. + * @param ensembleSize Ensemble size. + * @param samplePartSizePerMdl Size of sample part in percent to train one model. + * @param maxDeep Max decision tree deep. + * @param minImpurityDecrease Min impurity decrease. 
+ */ + public RandomForestClassifierTrainer(int featureVectorSize, + int maximumFeaturesCntPerMdl, + int ensembleSize, + double samplePartSizePerMdl, + int maxDeep, + double minImpurityDecrease) { + + this(new OnMajorityPredictionsAggregator(), featureVectorSize, maximumFeaturesCntPerMdl, + ensembleSize, samplePartSizePerMdl, maxDeep, minImpurityDecrease, null); + } + + /** {@inheritDoc} */ + @Override protected DatasetTrainer<DecisionTreeNode, Double> buildDatasetTrainerForModel() { + return new DecisionTreeClassificationTrainer(maxDeep, minImpurityDecrease); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/16a7c980/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainer.java new file mode 100644 index 0000000..d317683 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainer.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.tree.randomforest; + +import java.util.concurrent.ExecutorService; +import org.apache.ignite.ml.composition.predictionsaggregator.MeanValuePredictionsAggregator; +import org.apache.ignite.ml.composition.predictionsaggregator.PredictionsAggregator; +import org.apache.ignite.ml.trainers.DatasetTrainer; +import org.apache.ignite.ml.tree.DecisionTreeNode; +import org.apache.ignite.ml.tree.DecisionTreeRegressionTrainer; + +/** + * Random forest regression trainer. + */ +public class RandomForestRegressionTrainer extends RandomForestTrainer { + /** + * Constructs new instance of RandomForestRegressionTrainer. + * + * @param predictionsAggregator Predictions aggregator. + * @param featureVectorSize Feature vector size. + * @param maximumFeaturesCntPerMdl Number of features to draw from original features vector to train each model. + * @param ensembleSize Ensemble size. + * @param samplePartSizePerMdl Size of sample part in percent to train one model. + * @param maxDeep Max decision tree deep. + * @param minImpurityDecrease Min impurity decrease. + * @param threadPool Learning thread pool. + */ + public RandomForestRegressionTrainer(PredictionsAggregator predictionsAggregator, + int featureVectorSize, + int maximumFeaturesCntPerMdl, + int ensembleSize, + double samplePartSizePerMdl, + int maxDeep, + double minImpurityDecrease, + ExecutorService threadPool) { + + super(predictionsAggregator, featureVectorSize, maximumFeaturesCntPerMdl, + ensembleSize, samplePartSizePerMdl, maxDeep, minImpurityDecrease, threadPool); + } + + /** + * Constructs new instance of RandomForestRegressionTrainer. + * + * @param featureVectorSize Feature vector size. + * @param maximumFeaturesCntPerMdl Number of features to draw from original features vector to train each model. + * @param ensembleSize Ensemble size. 
+ * @param samplePartSizePerMdl Size of sample part in percent to train one model. + * @param maxDeep Max decision tree deep. + * @param minImpurityDecrease Min impurity decrease. + * @param threadPool Learning thread pool. + */ + public RandomForestRegressionTrainer(int featureVectorSize, + int maximumFeaturesCntPerMdl, + int ensembleSize, + double samplePartSizePerMdl, + int maxDeep, + double minImpurityDecrease, + ExecutorService threadPool) { + + this(new MeanValuePredictionsAggregator(), featureVectorSize, maximumFeaturesCntPerMdl, + ensembleSize, samplePartSizePerMdl, maxDeep, minImpurityDecrease, threadPool); + } + + /** + * Constructs new instance of RandomForestRegressionTrainer. + * + * @param featureVectorSize Feature vector size. + * @param maximumFeaturesCntPerMdl Number of features to draw from original features vector to train each model. + * @param ensembleSize Ensemble size. + * @param samplePartSizePerMdl Size of sample part in percent to train one model. + * @param maxDeep Max decision tree deep. + * @param minImpurityDecrease Min impurity decrease. 
+ */ + public RandomForestRegressionTrainer(int featureVectorSize, + int maximumFeaturesCntPerMdl, + int ensembleSize, + double samplePartSizePerMdl, + int maxDeep, double minImpurityDecrease) { + + this(new MeanValuePredictionsAggregator(), featureVectorSize, maximumFeaturesCntPerMdl, + ensembleSize, samplePartSizePerMdl, maxDeep, minImpurityDecrease, null); + } + + /** {@inheritDoc} */ + @Override protected DatasetTrainer<DecisionTreeNode, Double> buildDatasetTrainerForModel() { + return new DecisionTreeRegressionTrainer(maxDeep, minImpurityDecrease); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/16a7c980/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java new file mode 100644 index 0000000..4acf552 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.tree.randomforest; + +import java.util.concurrent.ExecutorService; +import org.apache.ignite.ml.composition.BaggingModelTrainer; +import org.apache.ignite.ml.composition.predictionsaggregator.PredictionsAggregator; + +/** + * Abstract random forest trainer. + */ +public abstract class RandomForestTrainer extends BaggingModelTrainer { + /** Max decision tree deep. */ + protected final int maxDeep; + /** Min impurity decrease. */ + protected final double minImpurityDecrease; + + /** + * Constructs new instance of BaggingModelTrainer. + * + * @param predictionsAggregator Predictions aggregator. + * @param featureVectorSize Feature vector size. + * @param maximumFeaturesCntPerMdl Number of features to draw from original features vector to train each model. + * @param ensembleSize Ensemble size. + * @param samplePartSizePerMdl Size of sample part in percent to train one model. + * @param maxDeep Max decision tree deep. + * @param minImpurityDecrease Min impurity decrease. + */ + public RandomForestTrainer(PredictionsAggregator predictionsAggregator, + int featureVectorSize, + int maximumFeaturesCntPerMdl, + int ensembleSize, + double samplePartSizePerMdl, + int maxDeep, + double minImpurityDecrease) { + + this(predictionsAggregator, featureVectorSize, maximumFeaturesCntPerMdl, + ensembleSize, samplePartSizePerMdl, maxDeep, minImpurityDecrease, null); + } + + /** + * Constructs new instance of BaggingModelTrainer. + * + * @param predictionsAggregator Predictions aggregator. + * @param featureVectorSize Feature vector size. + * @param maximumFeaturesCntPerMdl Number of features to draw from original features vector to train each model. + * @param ensembleSize Ensemble size. + * @param samplePartSizePerMdl Size of sample part in percent to train one model. + * @param maxDeep Max decision tree deep. + * @param minImpurityDecrease Min impurity decrease. + * @param threadPool Learning thread pool. 
+ */ + public RandomForestTrainer(PredictionsAggregator predictionsAggregator, + int featureVectorSize, + int maximumFeaturesCntPerMdl, + int ensembleSize, + double samplePartSizePerMdl, + int maxDeep, + double minImpurityDecrease, + ExecutorService threadPool) { + + super(predictionsAggregator, featureVectorSize, maximumFeaturesCntPerMdl, + ensembleSize, samplePartSizePerMdl, threadPool); + + this.maxDeep = maxDeep; + this.minImpurityDecrease = minImpurityDecrease; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/16a7c980/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/package-info.java new file mode 100644 index 0000000..abafed2 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Contains random forest implementation classes. 
+ */ +package org.apache.ignite.ml.tree.randomforest; http://git-wip-us.apache.org/repos/asf/ignite/blob/16a7c980/modules/ml/src/test/java/org/apache/ignite/ml/composition/predictionsaggregator/MeanValuePredictionsAggregatorTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/composition/predictionsaggregator/MeanValuePredictionsAggregatorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/composition/predictionsaggregator/MeanValuePredictionsAggregatorTest.java new file mode 100644 index 0000000..d99f4bc --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/composition/predictionsaggregator/MeanValuePredictionsAggregatorTest.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.composition.predictionsaggregator; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +public class MeanValuePredictionsAggregatorTest { + private PredictionsAggregator aggregator = new MeanValuePredictionsAggregator(); + + /** */ + @Test public void testApply() { + assertEquals(0.75, aggregator.apply(new double[]{1.0, 1.0, 1.0, 0.0}), 0.001); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/16a7c980/modules/ml/src/test/java/org/apache/ignite/ml/composition/predictionsaggregator/OnMajorityPredictionsAggregatorTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/composition/predictionsaggregator/OnMajorityPredictionsAggregatorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/composition/predictionsaggregator/OnMajorityPredictionsAggregatorTest.java new file mode 100644 index 0000000..52055ae --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/composition/predictionsaggregator/OnMajorityPredictionsAggregatorTest.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.composition.predictionsaggregator; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +public class OnMajorityPredictionsAggregatorTest { + private PredictionsAggregator aggregator = new OnMajorityPredictionsAggregator(); + + /** */ + @Test public void testApply() { + assertEquals(1.0, aggregator.apply(new double[]{1.0, 1.0, 1.0, 0.0}), 0.001); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/16a7c980/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeTestSuite.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeTestSuite.java index 2cbb486..867103e 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeTestSuite.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeTestSuite.java @@ -24,6 +24,8 @@ import org.apache.ignite.ml.tree.impurity.mse.MSEImpurityMeasureCalculatorTest; import org.apache.ignite.ml.tree.impurity.mse.MSEImpurityMeasureTest; import org.apache.ignite.ml.tree.impurity.util.SimpleStepFunctionCompressorTest; import org.apache.ignite.ml.tree.impurity.util.StepFunctionTest; +import org.apache.ignite.ml.tree.randomforest.RandomForestClassifierTrainerTest; +import org.apache.ignite.ml.tree.randomforest.RandomForestRegressionTrainerTest; import org.junit.runner.RunWith; import org.junit.runners.Suite; @@ -42,7 +44,9 @@ import org.junit.runners.Suite; MSEImpurityMeasureCalculatorTest.class, MSEImpurityMeasureTest.class, StepFunctionTest.class, - SimpleStepFunctionCompressorTest.class + SimpleStepFunctionCompressorTest.class, + RandomForestClassifierTrainerTest.class, + RandomForestRegressionTrainerTest.class }) public class DecisionTreeTestSuite { } 
http://git-wip-us.apache.org/repos/asf/ignite/blob/16a7c980/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java new file mode 100644 index 0000000..d581d6d --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.tree.randomforest; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.apache.ignite.ml.composition.ModelsComposition; +import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator; +import org.apache.ignite.ml.tree.DecisionTreeConditionalNode; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +@RunWith(Parameterized.class) +public class RandomForestClassifierTrainerTest { + /** + * Number of parts to be tested. + */ + private static final int[] partsToBeTested = new int[] {1, 2, 3, 4, 5, 7}; + + /** + * Number of partitions. + */ + @Parameterized.Parameter + public int parts; + + @Parameterized.Parameters(name = "Data divided on {0} partitions") + public static Iterable<Integer[]> data() { + List<Integer[]> res = new ArrayList<>(); + for (int part : partsToBeTested) + res.add(new Integer[] {part}); + + return res; + } + + /** */ + @Test public void testFit() { + int sampleSize = 1000; + Map<double[], Double> sample = new HashMap<>(); + for (int i = 0; i < sampleSize; i++) { + double x1 = i; + double x2 = x1 / 10.0; + double x3 = x2 / 10.0; + double x4 = x3 / 10.0; + + sample.put(new double[] {x1, x2, x3, x4}, (double)(i % 2)); + } + + RandomForestClassifierTrainer trainer = new RandomForestClassifierTrainer(4, 3, 5, 0.3, 4, 0.1); + ModelsComposition model = trainer.fit(sample, parts, (k, v) -> k, (k, v) -> v); + + assertTrue(model.getPredictionsAggregator() instanceof OnMajorityPredictionsAggregator); + assertEquals(5, model.getModels().size()); + + for (ModelsComposition.ModelOnFeaturesSubspace tree : model.getModels()) { + assertTrue(tree.getModel() instanceof DecisionTreeConditionalNode); + assertEquals(3, tree.getFeaturesMapping().size()); + } + } +} 
http://git-wip-us.apache.org/repos/asf/ignite/blob/16a7c980/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java new file mode 100644 index 0000000..f7594a3 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.tree.randomforest; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.apache.ignite.ml.composition.ModelsComposition; +import org.apache.ignite.ml.composition.predictionsaggregator.MeanValuePredictionsAggregator; +import org.apache.ignite.ml.tree.DecisionTreeConditionalNode; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +@RunWith(Parameterized.class) +public class RandomForestRegressionTrainerTest { + /** + * Number of parts to be tested. + */ + private static final int[] partsToBeTested = new int[] {1, 2, 3, 4, 5, 7}; + + /** + * Number of partitions. + */ + @Parameterized.Parameter + public int parts; + + @Parameterized.Parameters(name = "Data divided on {0} partitions") + public static Iterable<Integer[]> data() { + List<Integer[]> res = new ArrayList<>(); + for (int part : partsToBeTested) + res.add(new Integer[] {part}); + + return res; + } + + /** */ + @Test public void testFit() { + int sampleSize = 1000; + Map<Double, double[]> sample = new HashMap<>(); + for (int i = 0; i < sampleSize; i++) { + double x1 = i; + double x2 = x1 / 10.0; + double x3 = x2 / 10.0; + double x4 = x3 / 10.0; + + sample.put(x1 * x2 + x3 * x4, new double[] {x1, x2, x3, x4}); + } + + RandomForestRegressionTrainer trainer = new RandomForestRegressionTrainer(4, 3, 5, 0.3, 4, 0.1); + ModelsComposition model = trainer.fit(sample, parts, (k, v) -> v, (k, v) -> k); + + assertTrue(model.getPredictionsAggregator() instanceof MeanValuePredictionsAggregator); + assertEquals(5, model.getModels().size()); + + for (ModelsComposition.ModelOnFeaturesSubspace tree : model.getModels()) { + assertTrue(tree.getModel() instanceof DecisionTreeConditionalNode); + assertEquals(3, tree.getFeaturesMapping().size()); + } + } +}