http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/PreprocessingTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/PreprocessingTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/PreprocessingTrainer.java index b977864..89751eb 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/PreprocessingTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/PreprocessingTrainer.java @@ -23,6 +23,7 @@ import org.apache.ignite.IgniteCache; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder; import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; +import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; import org.apache.ignite.ml.math.functions.IgniteBiFunction; /** @@ -37,24 +38,61 @@ public interface PreprocessingTrainer<K, V, T, R> { /** * Fits preprocessor. * + * @param envBuilder Learning environment builder. * @param datasetBuilder Dataset builder. * @param basePreprocessor Base preprocessor. * @return Preprocessor. */ - public IgniteBiFunction<K, V, R> fit(DatasetBuilder<K, V> datasetBuilder, + public IgniteBiFunction<K, V, R> fit( + LearningEnvironmentBuilder envBuilder, + DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, T> basePreprocessor); /** * Fits preprocessor. * + * @param datasetBuilder Dataset builder. + * @param basePreprocessor Base preprocessor. + * @return Preprocessor. + */ + public default IgniteBiFunction<K, V, R> fit( + DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, T> basePreprocessor) { + return fit(LearningEnvironmentBuilder.defaultBuilder(), datasetBuilder, basePreprocessor); + } + + /** + * Fits preprocessor. + * + * @param ignite Ignite instance. + * @param cache Ignite cache. 
+ * @param basePreprocessor Base preprocessor. + * @return Preprocessor. + */ + public default IgniteBiFunction<K, V, R> fit( + Ignite ignite, IgniteCache<K, V> cache, + IgniteBiFunction<K, V, T> basePreprocessor) { + return fit( + new CacheBasedDatasetBuilder<>(ignite, cache), + basePreprocessor + ); + } + + /** + * Fits preprocessor. + * + * @param envBuilder Learning environment builder. * @param ignite Ignite instance. * @param cache Ignite cache. * @param basePreprocessor Base preprocessor. * @return Preprocessor. */ - public default IgniteBiFunction<K, V, R> fit(Ignite ignite, IgniteCache<K, V> cache, + public default IgniteBiFunction<K, V, R> fit( + LearningEnvironmentBuilder envBuilder, + Ignite ignite, IgniteCache<K, V> cache, IgniteBiFunction<K, V, T> basePreprocessor) { return fit( + envBuilder, new CacheBasedDatasetBuilder<>(ignite, cache), basePreprocessor ); @@ -68,7 +106,29 @@ public interface PreprocessingTrainer<K, V, T, R> { * @param basePreprocessor Base preprocessor. * @return Preprocessor. */ - public default IgniteBiFunction<K, V, R> fit(Map<K, V> data, int parts, + public default IgniteBiFunction<K, V, R> fit( + LearningEnvironmentBuilder envBuilder, + Map<K, V> data, + int parts, + IgniteBiFunction<K, V, T> basePreprocessor) { + return fit( + envBuilder, + new LocalDatasetBuilder<>(data, parts), + basePreprocessor + ); + } + + /** + * Fits preprocessor. + * + * @param data Data. + * @param parts Number of partitions. + * @param basePreprocessor Base preprocessor. + * @return Preprocessor. + */ + public default IgniteBiFunction<K, V, R> fit( + Map<K, V> data, + int parts, IgniteBiFunction<K, V, T> basePreprocessor) { return fit( new LocalDatasetBuilder<>(data, parts),
http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationTrainer.java index ad8c90e..039794c 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationTrainer.java @@ -18,6 +18,7 @@ package org.apache.ignite.ml.preprocessing.binarization; import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.preprocessing.PreprocessingTrainer; @@ -33,7 +34,9 @@ public class BinarizationTrainer<K, V> implements PreprocessingTrainer<K, V, Vec private double threshold; /** {@inheritDoc} */ - @Override public BinarizationPreprocessor<K, V> fit(DatasetBuilder<K, V> datasetBuilder, + @Override public BinarizationPreprocessor<K, V> fit( + LearningEnvironmentBuilder envBuilder, + DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> basePreprocessor) { return new BinarizationPreprocessor<>(threshold, basePreprocessor); } http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainer.java index d5668e4..14a509e 100644 --- 
a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainer.java @@ -29,6 +29,7 @@ import org.apache.ignite.ml.dataset.Dataset; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.UpstreamEntry; import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.preprocessing.PreprocessingTrainer; @@ -53,14 +54,17 @@ public class EncoderTrainer<K, V> implements PreprocessingTrainer<K, V, Object[] private EncoderSortingStrategy encoderSortingStgy = EncoderSortingStrategy.FREQUENCY_DESC; /** {@inheritDoc} */ - @Override public EncoderPreprocessor<K, V> fit(DatasetBuilder<K, V> datasetBuilder, - IgniteBiFunction<K, V, Object[]> basePreprocessor) { + @Override public EncoderPreprocessor<K, V> fit( + LearningEnvironmentBuilder envBuilder, + DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Object[]> basePreprocessor) { if (handledIndices.isEmpty()) throw new RuntimeException("Add indices of handled features"); try (Dataset<EmptyContext, EncoderPartitionData> dataset = datasetBuilder.build( - (upstream, upstreamSize) -> new EmptyContext(), - (upstream, upstreamSize, ctx) -> { + envBuilder, + (env, upstream, upstreamSize) -> new EmptyContext(), + (env, upstream, upstreamSize, ctx) -> { // This array will contain not null values for handled indices Map<String, Integer>[] categoryFrequencies = null; http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/imputing/ImputerTrainer.java ---------------------------------------------------------------------- diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/imputing/ImputerTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/imputing/ImputerTrainer.java index 090b0a4..e8920f3 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/imputing/ImputerTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/imputing/ImputerTrainer.java @@ -23,8 +23,10 @@ import java.util.Map; import java.util.Optional; import org.apache.ignite.ml.dataset.Dataset; import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.dataset.PartitionContextBuilder; import org.apache.ignite.ml.dataset.UpstreamEntry; import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.math.primitives.vector.VectorUtils; @@ -43,11 +45,13 @@ public class ImputerTrainer<K, V> implements PreprocessingTrainer<K, V, Vector, private ImputingStrategy imputingStgy = ImputingStrategy.MEAN; /** {@inheritDoc} */ - @Override public ImputerPreprocessor<K, V> fit(DatasetBuilder<K, V> datasetBuilder, + @Override public ImputerPreprocessor<K, V> fit(LearningEnvironmentBuilder envBuilder, DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> basePreprocessor) { + PartitionContextBuilder<K, V, EmptyContext> builder = (env, upstream, upstreamSize) -> new EmptyContext(); try (Dataset<EmptyContext, ImputerPartitionData> dataset = datasetBuilder.build( - (upstream, upstreamSize) -> new EmptyContext(), - (upstream, upstreamSize, ctx) -> { + envBuilder, + builder, + (env, upstream, upstreamSize, ctx) -> { double[] sums = null; int[] counts = null; Map<Double, Integer>[] valuesByFreq = null; 
http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/maxabsscaling/MaxAbsScalerTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/maxabsscaling/MaxAbsScalerTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/maxabsscaling/MaxAbsScalerTrainer.java index c8b1dca..52acea3 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/maxabsscaling/MaxAbsScalerTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/maxabsscaling/MaxAbsScalerTrainer.java @@ -21,6 +21,7 @@ import org.apache.ignite.ml.dataset.Dataset; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.UpstreamEntry; import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.preprocessing.PreprocessingTrainer; @@ -33,11 +34,14 @@ import org.apache.ignite.ml.preprocessing.PreprocessingTrainer; */ public class MaxAbsScalerTrainer<K, V> implements PreprocessingTrainer<K, V, Vector, Vector> { /** {@inheritDoc} */ - @Override public MaxAbsScalerPreprocessor<K, V> fit(DatasetBuilder<K, V> datasetBuilder, + @Override public MaxAbsScalerPreprocessor<K, V> fit( + LearningEnvironmentBuilder envBuilder, + DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> basePreprocessor) { try (Dataset<EmptyContext, MaxAbsScalerPartitionData> dataset = datasetBuilder.build( - (upstream, upstreamSize) -> new EmptyContext(), - (upstream, upstreamSize, ctx) -> { + envBuilder, + (env, upstream, upstreamSize) -> new EmptyContext(), + (env, upstream, upstreamSize, ctx) -> { double[] maxAbs = null; while (upstream.hasNext()) { 
http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/minmaxscaling/MinMaxScalerTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/minmaxscaling/MinMaxScalerTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/minmaxscaling/MinMaxScalerTrainer.java index 6a39236..71f2afc 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/minmaxscaling/MinMaxScalerTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/minmaxscaling/MinMaxScalerTrainer.java @@ -19,8 +19,10 @@ package org.apache.ignite.ml.preprocessing.minmaxscaling; import org.apache.ignite.ml.dataset.Dataset; import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.dataset.PartitionContextBuilder; import org.apache.ignite.ml.dataset.UpstreamEntry; import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.preprocessing.PreprocessingTrainer; @@ -33,11 +35,15 @@ import org.apache.ignite.ml.preprocessing.PreprocessingTrainer; */ public class MinMaxScalerTrainer<K, V> implements PreprocessingTrainer<K, V, Vector, Vector> { /** {@inheritDoc} */ - @Override public MinMaxScalerPreprocessor<K, V> fit(DatasetBuilder<K, V> datasetBuilder, + @Override public MinMaxScalerPreprocessor<K, V> fit( + LearningEnvironmentBuilder envBuilder, + DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> basePreprocessor) { + PartitionContextBuilder<K, V, EmptyContext> ctxBuilder = (env, upstream, upstreamSize) -> new EmptyContext(); try (Dataset<EmptyContext, MinMaxScalerPartitionData> dataset = datasetBuilder.build( - (upstream, upstreamSize) -> 
new EmptyContext(), - (upstream, upstreamSize, ctx) -> { + envBuilder, + ctxBuilder, + (env, upstream, upstreamSize, ctx) -> { double[] min = null; double[] max = null; http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainer.java index b2dc6ed..08c4a68 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainer.java @@ -18,6 +18,7 @@ package org.apache.ignite.ml.preprocessing.normalization; import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.preprocessing.PreprocessingTrainer; @@ -33,7 +34,9 @@ public class NormalizationTrainer<K, V> implements PreprocessingTrainer<K, V, Ve private int p = 2; /** {@inheritDoc} */ - @Override public NormalizationPreprocessor<K, V> fit(DatasetBuilder<K, V> datasetBuilder, + @Override public NormalizationPreprocessor<K, V> fit( + LearningEnvironmentBuilder envBuilder, + DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> basePreprocessor) { return new NormalizationPreprocessor<>(p, basePreprocessor); } http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerTrainer.java ---------------------------------------------------------------------- diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerTrainer.java index 5147b05..604f0b0 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerTrainer.java @@ -21,6 +21,7 @@ import org.apache.ignite.ml.dataset.Dataset; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.UpstreamEntry; import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.preprocessing.PreprocessingTrainer; @@ -33,9 +34,10 @@ import org.apache.ignite.ml.preprocessing.PreprocessingTrainer; */ public class StandardScalerTrainer<K, V> implements PreprocessingTrainer<K, V, Vector, Vector> { /** {@inheritDoc} */ - @Override public StandardScalerPreprocessor<K, V> fit(DatasetBuilder<K, V> datasetBuilder, + @Override public StandardScalerPreprocessor<K, V> fit(LearningEnvironmentBuilder envBuilder, + DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> basePreprocessor) { - StandardScalerData standardScalerData = computeSum(datasetBuilder, basePreprocessor); + StandardScalerData standardScalerData = computeSum(envBuilder, datasetBuilder, basePreprocessor); int n = standardScalerData.sum.length; long cnt = standardScalerData.cnt; @@ -51,11 +53,13 @@ public class StandardScalerTrainer<K, V> implements PreprocessingTrainer<K, V, V } /** Computes sum, squared sum and row count. 
*/ - private StandardScalerData computeSum(DatasetBuilder<K, V> datasetBuilder, + private StandardScalerData computeSum(LearningEnvironmentBuilder envBuilder, + DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> basePreprocessor) { try (Dataset<EmptyContext, StandardScalerData> dataset = datasetBuilder.build( - (upstream, upstreamSize) -> new EmptyContext(), - (upstream, upstreamSize, ctx) -> { + envBuilder, + (env, upstream, upstreamSize) -> new EmptyContext(), + (env, upstream, upstreamSize, ctx) -> { double[] sum = null; double[] squaredSum = null; long cnt = 0; http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java index 5497177..dc245d2 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java @@ -50,6 +50,7 @@ public class LinearRegressionLSQRTrainer extends SingleLabelDatasetTrainer<Linea try (LSQROnHeap<K, V> lsqr = new LSQROnHeap<>( datasetBuilder, + envBuilder, new SimpleLabeledDatasetDataBuilder<>( new FeatureExtractorWrapper<>(featureExtractor), lbExtractor.andThen(e -> new double[] {e}) http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java 
b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java index 71d54fa..fd5a624 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java @@ -136,7 +136,8 @@ public class LogRegressionMultiClassTrainer<P extends Serializable> List<Double> res = new ArrayList<>(); try (Dataset<EmptyContext, LabelPartitionDataOnHeap> dataset = datasetBuilder.build( - (upstream, upstreamSize) -> new EmptyContext(), + envBuilder, + (env, upstream, upstreamSize) -> new EmptyContext(), partDataBuilder )) { final Set<Double> clsLabels = dataset.compute(data -> { http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabelPartitionDataBuilderOnHeap.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabelPartitionDataBuilderOnHeap.java b/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabelPartitionDataBuilderOnHeap.java index 4fba028..5549b08 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabelPartitionDataBuilderOnHeap.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabelPartitionDataBuilderOnHeap.java @@ -21,6 +21,7 @@ import java.io.Serializable; import java.util.Iterator; import org.apache.ignite.ml.dataset.PartitionDataBuilder; import org.apache.ignite.ml.dataset.UpstreamEntry; +import org.apache.ignite.ml.environment.LearningEnvironment; import org.apache.ignite.ml.math.functions.IgniteBiFunction; /** @@ -48,8 +49,11 @@ public class LabelPartitionDataBuilderOnHeap<K, V, C extends Serializable> } /** {@inheritDoc} */ - @Override public LabelPartitionDataOnHeap 
build(Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize, - C ctx) { + @Override public LabelPartitionDataOnHeap build( + LearningEnvironment env, + Iterator<UpstreamEntry<K, V>> upstreamData, + long upstreamDataSize, + C ctx) { double[] y = new double[Math.toIntExact(upstreamDataSize)]; int ptr = 0; http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabeledDatasetPartitionDataBuilderOnHeap.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabeledDatasetPartitionDataBuilderOnHeap.java b/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabeledDatasetPartitionDataBuilderOnHeap.java index 0351037..0d054f6 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabeledDatasetPartitionDataBuilderOnHeap.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabeledDatasetPartitionDataBuilderOnHeap.java @@ -21,6 +21,7 @@ import java.io.Serializable; import java.util.Iterator; import org.apache.ignite.ml.dataset.PartitionDataBuilder; import org.apache.ignite.ml.dataset.UpstreamEntry; +import org.apache.ignite.ml.environment.LearningEnvironment; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.structures.LabeledVector; @@ -57,8 +58,10 @@ public class LabeledDatasetPartitionDataBuilderOnHeap<K, V, C extends Serializab } /** {@inheritDoc} */ - @Override public LabeledVectorSet<Double, LabeledVector> build(Iterator<UpstreamEntry<K, V>> upstreamData, - long upstreamDataSize, C ctx) { + @Override public LabeledVectorSet<Double, LabeledVector> build( + LearningEnvironment env, + Iterator<UpstreamEntry<K, V>> upstreamData, + long upstreamDataSize, C ctx) { int xCols = -1; double[][] x = null; double[] y = new 
double[Math.toIntExact(upstreamDataSize)]; http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java index 47666f4..7ceb53b 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java @@ -89,7 +89,8 @@ public class SVMLinearBinaryClassificationTrainer extends SingleLabelDatasetTrai Vector weights; try (Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset = datasetBuilder.build( - (upstream, upstreamSize) -> new EmptyContext(), + envBuilder, + (env, upstream, upstreamSize) -> new EmptyContext(), partDataBuilder )) { if (mdl == null) { http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java index b161914..94f2a99 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java @@ -157,7 +157,8 @@ public class SVMLinearMultiClassClassificationTrainer List<Double> res = new ArrayList<>(); try (Dataset<EmptyContext, LabelPartitionDataOnHeap> dataset = datasetBuilder.build( - (upstream, upstreamSize) -> new EmptyContext(), + envBuilder, + 
(env, upstream, upstreamSize) -> new EmptyContext(), partDataBuilder )) { final Set<Double> clsLabels = dataset.compute(data -> { http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java index f321744..dabf66a 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java @@ -26,6 +26,7 @@ import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder; import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; import org.apache.ignite.ml.environment.LearningEnvironment; +import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; import org.apache.ignite.ml.environment.logging.MLLogger; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; @@ -38,8 +39,11 @@ import org.jetbrains.annotations.NotNull; * @param <L> Type of a label. */ public abstract class DatasetTrainer<M extends Model, L> { + /** Learning environment builder. */ + protected LearningEnvironmentBuilder envBuilder = LearningEnvironmentBuilder.defaultBuilder(); + /** Learning Environment. */ - protected LearningEnvironment environment = LearningEnvironment.DEFAULT; + protected LearningEnvironment environment = envBuilder.buildForTrainer(); /** * Trains model based on the specified data. @@ -289,11 +293,25 @@ public abstract class DatasetTrainer<M extends Model, L> { } /** - * Sets learning Environment - * @param environment Environment. + * Changes learning Environment. + * + * @param envBuilder Learning environment builder. 
+ */ + // TODO: IGNITE-10441 Think about more elegant ways to perform fluent API. + public DatasetTrainer<M, L> withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) { + this.envBuilder = envBuilder; + this.environment = envBuilder.buildForTrainer(); + + return this; + } + + /** + * Get learning environment. + * + * @return Learning environment. */ - public void setEnvironment(LearningEnvironment environment) { - this.environment = environment; + public LearningEnvironment learningEnvironment() { + return environment; } /** http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/trainers/TrainerTransformers.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/TrainerTransformers.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/TrainerTransformers.java index 05504c3..1019a39 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/TrainerTransformers.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/TrainerTransformers.java @@ -17,21 +17,15 @@ package org.apache.ignite.ml.trainers; -import java.io.Serializable; import java.util.ArrayList; import java.util.List; import java.util.Random; import java.util.stream.Collectors; import java.util.stream.IntStream; -import org.apache.ignite.lang.IgniteBiPredicate; import org.apache.ignite.ml.Model; import org.apache.ignite.ml.composition.ModelsComposition; import org.apache.ignite.ml.composition.predictionsaggregator.PredictionsAggregator; -import org.apache.ignite.ml.dataset.Dataset; import org.apache.ignite.ml.dataset.DatasetBuilder; -import org.apache.ignite.ml.dataset.PartitionContextBuilder; -import org.apache.ignite.ml.dataset.PartitionDataBuilder; -import org.apache.ignite.ml.dataset.UpstreamTransformerChain; import org.apache.ignite.ml.environment.LearningEnvironment; import org.apache.ignite.ml.environment.logging.MLLogger; import 
org.apache.ignite.ml.environment.parallelism.Promise; @@ -49,7 +43,7 @@ import org.apache.ignite.ml.util.Utils; */ public class TrainerTransformers { /** - * Add bagging logic to a given trainer. + * Add bagging logic to a given trainer. No features bootstrapping is done. * * @param ensembleSize Size of ensemble. * @param subsampleRatio Subsample ratio to whole dataset. @@ -63,9 +57,8 @@ public class TrainerTransformers { int ensembleSize, double subsampleRatio, PredictionsAggregator aggregator) { - return makeBagged(trainer, ensembleSize, subsampleRatio, -1, -1, aggregator, new Random().nextLong()); + return makeBagged(trainer, ensembleSize, subsampleRatio, -1, -1, aggregator); } - /** * Add bagging logic to a given trainer. * @@ -74,31 +67,23 @@ public class TrainerTransformers { * @param aggregator Aggregator. * @param featureVectorSize Feature vector dimensionality. * @param featuresSubspaceDim Feature subspace dimensionality. - * @param transformationSeed Transformations seed. * @param <M> Type of one model in ensemble. * @param <L> Type of labels. * @return Bagged trainer. */ - // TODO: IGNITE-10296: Inject capabilities of seeding through learning environment (remove). public static <M extends Model<Vector, Double>, L> DatasetTrainer<ModelsComposition, L> makeBagged( DatasetTrainer<M, L> trainer, int ensembleSize, double subsampleRatio, int featureVectorSize, int featuresSubspaceDim, - PredictionsAggregator aggregator, - Long transformationSeed) { + PredictionsAggregator aggregator) { return new DatasetTrainer<ModelsComposition, L>() { /** {@inheritDoc} */ @Override public <K, V> ModelsComposition fit( DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) { - datasetBuilder.upstreamTransformersChain().setSeed( - transformationSeed == null - ? 
new Random().nextLong() - : transformationSeed); - return runOnEnsemble( (db, i, fe) -> (() -> trainer.fit(db, fe, lbExtractor)), datasetBuilder, @@ -172,21 +157,17 @@ public class TrainerTransformers { log.log(MLLogger.VerboseLevel.LOW, "Start learning."); List<int[]> mappings = null; - if (featuresVectorSize > 0) { + if (featuresVectorSize > 0 && featureSubspaceDim != featuresVectorSize) { mappings = IntStream.range(0, ensembleSize).mapToObj( modelIdx -> getMapping( featuresVectorSize, featureSubspaceDim, - datasetBuilder.upstreamTransformersChain().seed() + modelIdx)) + environment.randomNumbersGenerator().nextLong() + modelIdx)) .collect(Collectors.toList()); } Long startTs = System.currentTimeMillis(); - datasetBuilder - .upstreamTransformersChain() - .addUpstreamTransformer(new BaggingUpstreamTransformer<>(subsampleRatio)); - List<IgniteSupplier<M>> tasks = new ArrayList<>(); List<IgniteBiFunction<K, V, Vector>> extractors = new ArrayList<>(); if (mappings != null) { @@ -195,10 +176,8 @@ public class TrainerTransformers { } for (int i = 0; i < ensembleSize; i++) { - UpstreamTransformerChain<K, V> newChain = Utils.copy(datasetBuilder.upstreamTransformersChain()); - DatasetBuilder<K, V> newBuilder = withNewChain(datasetBuilder, newChain); - int j = i; - newChain.modifySeed(s -> s * s + j); + DatasetBuilder<K, V> newBuilder = + datasetBuilder.withUpstreamTransformer(BaggingUpstreamTransformer.builder(subsampleRatio, i)); tasks.add( trainingTaskGenerator.apply(newBuilder, i, mappings != null ? extractors.get(i) : extractor)); } @@ -338,37 +317,4 @@ public class TrainerTransformers { return mapping; } } - - /** - * Creates new dataset builder which is delegate of a given dataset builder in everything except - * new transformations chain. - * - * @param builder Initial builder. - * @param chain New chain. - * @param <K> Type of keys. - * @param <V> Type of values. 
- * @return new dataset builder which is delegate of a given dataset builder in everything except - * new transformations chain. - */ - private static <K, V> DatasetBuilder<K, V> withNewChain( - DatasetBuilder<K, V> builder, - UpstreamTransformerChain<K, V> chain) { - return new DatasetBuilder<K, V>() { - /** {@inheritDoc} */ - @Override public <C extends Serializable, D extends AutoCloseable> Dataset<C, D> build( - PartitionContextBuilder<K, V, C> partCtxBuilder, PartitionDataBuilder<K, V, C, D> partDataBuilder) { - return builder.build(partCtxBuilder, partDataBuilder); - } - - /** {@inheritDoc} */ - @Override public UpstreamTransformerChain<K, V> upstreamTransformersChain() { - return chain; - } - - /** {@inheritDoc} */ - @Override public DatasetBuilder<K, V> withFilter(IgniteBiPredicate<K, V> filterToAdd) { - return builder.withFilter(filterToAdd); - } - }; - } } http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/trainers/transformers/BaggingUpstreamTransformer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/transformers/BaggingUpstreamTransformer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/transformers/BaggingUpstreamTransformer.java index f935ebd..7f45fdd 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/transformers/BaggingUpstreamTransformer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/transformers/BaggingUpstreamTransformer.java @@ -17,12 +17,12 @@ package org.apache.ignite.ml.trainers.transformers; -import java.util.Random; import java.util.stream.Stream; import org.apache.commons.math3.distribution.PoissonDistribution; import org.apache.commons.math3.random.Well19937c; import org.apache.ignite.ml.dataset.UpstreamEntry; import org.apache.ignite.ml.dataset.UpstreamTransformer; +import org.apache.ignite.ml.dataset.UpstreamTransformerBuilder; /** * This 
class encapsulates the logic needed to do bagging (bootstrap aggregating) by features. @@ -33,22 +33,43 @@ import org.apache.ignite.ml.dataset.UpstreamTransformer; * @param <V> Type of upstream values. */ public class BaggingUpstreamTransformer<K, V> implements UpstreamTransformer<K, V> { + /** Serial version uid. */ + private static final long serialVersionUID = -913152523469994149L; + /** Ratio of subsample to entire upstream size */ private double subsampleRatio; + /** Seed used for generating poisson distribution. */ + private long seed; + + /** + * Get builder of {@link BaggingUpstreamTransformer} for a model with a specified index in ensemble. + * + * @param subsampleRatio Subsample ratio. + * @param mdlIdx Index of model in ensemble. + * @param <K> Type of upstream keys. + * @param <V> Type of upstream values. + * @return Builder of {@link BaggingUpstreamTransformer}. + */ + public static <K, V> UpstreamTransformerBuilder<K, V> builder(double subsampleRatio, int mdlIdx) { + return env -> new BaggingUpstreamTransformer<>(env.randomNumbersGenerator().nextLong() + mdlIdx, subsampleRatio); + } + /** * Construct instance of this transformer with a given subsample ratio. * + * @param seed Seed used for generating poisson distribution which in turn used to make subsamples. * @param subsampleRatio Subsample ratio. 
*/ - public BaggingUpstreamTransformer(double subsampleRatio) { + public BaggingUpstreamTransformer(long seed, double subsampleRatio) { this.subsampleRatio = subsampleRatio; + this.seed = seed; } /** {@inheritDoc} */ - @Override public Stream<UpstreamEntry<K, V>> transform(Random rnd, Stream<UpstreamEntry<K, V>> upstream) { + @Override public Stream<UpstreamEntry<K, V>> transform(Stream<UpstreamEntry<K, V>> upstream) { PoissonDistribution poisson = new PoissonDistribution( - new Well19937c(rnd.nextLong()), + new Well19937c(seed), subsampleRatio, PoissonDistribution.DEFAULT_EPSILON, PoissonDistribution.DEFAULT_MAX_ITERATIONS); http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java index 482c938..510d26e 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java @@ -24,6 +24,7 @@ import org.apache.ignite.ml.dataset.Dataset; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder; import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.trainers.DatasetTrainer; @@ -76,6 +77,7 @@ public abstract class DecisionTree<T extends ImpurityMeasure<T>> extends Dataset @Override public <K, V> DecisionTreeNode fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { try (Dataset<EmptyContext, DecisionTreeData> 
dataset = datasetBuilder.build( + envBuilder, new EmptyContextBuilder<>(), new DecisionTreeDataBuilder<>(featureExtractor, lbExtractor, usingIdx) )) { @@ -108,6 +110,11 @@ public abstract class DecisionTree<T extends ImpurityMeasure<T>> extends Dataset return true; } + /** {@inheritDoc} */ + @Override public DecisionTree<T> withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) { + return (DecisionTree<T>)super.withEnvironmentBuilder(envBuilder); + } + /** */ public <K,V> DecisionTreeNode fit(Dataset<EmptyContext, DecisionTreeData> dataset) { return split(dataset, e -> true, 0, getImpurityMeasureCalculator(dataset)); http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java index 58552f4..321e65f 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java @@ -23,6 +23,7 @@ import java.util.Map; import java.util.Set; import org.apache.ignite.ml.dataset.Dataset; import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; import org.apache.ignite.ml.tree.data.DecisionTreeData; import org.apache.ignite.ml.tree.impurity.ImpurityMeasureCalculator; import org.apache.ignite.ml.tree.impurity.gini.GiniImpurityMeasure; @@ -129,4 +130,9 @@ public class DecisionTreeClassificationTrainer extends DecisionTree<GiniImpurity return new GiniImpurityMeasureCalculator(encoder, usingIdx); } + + /** {@inheritDoc} */ + @Override public DecisionTreeClassificationTrainer withEnvironmentBuilder(LearningEnvironmentBuilder 
envBuilder) { + return (DecisionTreeClassificationTrainer)super.withEnvironmentBuilder(envBuilder); + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java index ea57bcc..2b259f2 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java @@ -19,6 +19,7 @@ package org.apache.ignite.ml.tree; import org.apache.ignite.ml.dataset.Dataset; import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; import org.apache.ignite.ml.tree.data.DecisionTreeData; import org.apache.ignite.ml.tree.impurity.ImpurityMeasureCalculator; import org.apache.ignite.ml.tree.impurity.mse.MSEImpurityMeasure; @@ -69,4 +70,9 @@ public class DecisionTreeRegressionTrainer extends DecisionTree<MSEImpurityMeasu return new MSEImpurityMeasureCalculator(usingIdx); } + + /** {@inheritDoc} */ + @Override public DecisionTreeRegressionTrainer withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) { + return (DecisionTreeRegressionTrainer)super.withEnvironmentBuilder(envBuilder); + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBBinaryClassifierOnTreesTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBBinaryClassifierOnTreesTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBBinaryClassifierOnTreesTrainer.java index b99dc2f..b19652d 
100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBBinaryClassifierOnTreesTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBBinaryClassifierOnTreesTrainer.java @@ -19,6 +19,7 @@ package org.apache.ignite.ml.tree.boosting; import org.apache.ignite.ml.composition.boosting.GDBBinaryClassifierTrainer; import org.apache.ignite.ml.composition.boosting.GDBLearningStrategy; +import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; import org.apache.ignite.ml.tree.DecisionTreeRegressionTrainer; import org.jetbrains.annotations.NotNull; @@ -61,6 +62,11 @@ public class GDBBinaryClassifierOnTreesTrainer extends GDBBinaryClassifierTraine return new GDBOnTreesLearningStrategy(usingIdx); } + /** {@inheritDoc} */ + @Override public GDBBinaryClassifierOnTreesTrainer withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) { + return (GDBBinaryClassifierOnTreesTrainer)super.withEnvironmentBuilder(envBuilder); + } + /** * Set useIndex parameter and returns trainer instance. 
* http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBOnTreesLearningStrategy.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBOnTreesLearningStrategy.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBOnTreesLearningStrategy.java index caac168..71e840c 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBOnTreesLearningStrategy.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBOnTreesLearningStrategy.java @@ -70,6 +70,7 @@ public class GDBOnTreesLearningStrategy extends GDBLearningStrategy { externalLbToInternalMapping, loss, datasetBuilder, featureExtractor, lbExtractor); try (Dataset<EmptyContext, DecisionTreeData> dataset = datasetBuilder.build( + envBuilder, new EmptyContextBuilder<>(), new DecisionTreeDataBuilder<>(featureExtractor, lbExtractor, useIdx) )) { @@ -95,7 +96,7 @@ public class GDBOnTreesLearningStrategy extends GDBLearningStrategy { long startTs = System.currentTimeMillis(); models.add(decisionTreeTrainer.fit(dataset)); double learningTime = (double)(System.currentTimeMillis() - startTs) / 1000.0; - environment.logger(getClass()).log(MLLogger.VerboseLevel.LOW, "One model training time was %.2fs", learningTime); + trainerEnvironment.logger(getClass()).log(MLLogger.VerboseLevel.LOW, "One model training time was %.2fs", learningTime); } } catch (Exception e) { http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBRegressionOnTreesTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBRegressionOnTreesTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBRegressionOnTreesTrainer.java index b6c0b48..9c588ce 100644 --- 
a/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBRegressionOnTreesTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBRegressionOnTreesTrainer.java @@ -19,6 +19,7 @@ package org.apache.ignite.ml.tree.boosting; import org.apache.ignite.ml.composition.boosting.GDBLearningStrategy; import org.apache.ignite.ml.composition.boosting.GDBRegressionTrainer; +import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; import org.apache.ignite.ml.tree.DecisionTreeRegressionTrainer; import org.jetbrains.annotations.NotNull; @@ -120,4 +121,9 @@ public class GDBRegressionOnTreesTrainer extends GDBRegressionTrainer { @Override protected GDBLearningStrategy getLearningStrategy() { return new GDBOnTreesLearningStrategy(usingIdx); } + + /** {@inheritDoc} */ + @Override public GDBRegressionOnTreesTrainer withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) { + return (GDBRegressionOnTreesTrainer)super.withEnvironmentBuilder(envBuilder); + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeDataBuilder.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeDataBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeDataBuilder.java index 4436b07..1378120 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeDataBuilder.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeDataBuilder.java @@ -21,6 +21,7 @@ import java.io.Serializable; import java.util.Iterator; import org.apache.ignite.ml.dataset.PartitionDataBuilder; import org.apache.ignite.ml.dataset.UpstreamEntry; +import org.apache.ignite.ml.environment.LearningEnvironment; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; @@ -60,7 +61,11 @@ 
public class DecisionTreeDataBuilder<K, V, C extends Serializable> } /** {@inheritDoc} */ - @Override public DecisionTreeData build(Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize, C ctx) { + @Override public DecisionTreeData build( + LearningEnvironment envBuilder, + Iterator<UpstreamEntry<K, V>> upstreamData, + long upstreamDataSize, + C ctx) { double[][] features = new double[Math.toIntExact(upstreamDataSize)][]; double[] labels = new double[Math.toIntExact(upstreamDataSize)]; http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java index 72a97c4..3ee90cb 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java @@ -114,6 +114,7 @@ public abstract class RandomForestTrainer<L, S extends ImpurityComputer<Bootstra IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { List<TreeRoot> models = null; try (Dataset<EmptyContext, BootstrappedDatasetPartition> dataset = datasetBuilder.build( + envBuilder, new EmptyContextBuilder<>(), new BootstrappedDatasetBuilder<>(featureExtractor, lbExtractor, amountOfTrees, subSampleSize))) { http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/test/java/org/apache/ignite/ml/TestUtils.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/TestUtils.java b/modules/ml/src/test/java/org/apache/ignite/ml/TestUtils.java index 4b472cc..1103ef0 100644 --- 
a/modules/ml/src/test/java/org/apache/ignite/ml/TestUtils.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/TestUtils.java @@ -18,6 +18,7 @@ package org.apache.ignite.ml; import java.util.stream.IntStream; +import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; import org.apache.ignite.ml.math.primitives.matrix.Matrix; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.junit.Assert; @@ -325,4 +326,82 @@ public class TestUtils { } } + + /** + * Gets test learning environment builder. + * + * @return test learning environment builder. + */ + public static LearningEnvironmentBuilder testEnvBuilder() { + return testEnvBuilder(123L); + } + + /** + * Gets test learning environment builder with a given seed. + * + * @param seed Seed. + * @return test learning environment builder. + */ + public static LearningEnvironmentBuilder testEnvBuilder(long seed) { + return LearningEnvironmentBuilder.defaultBuilder().withRNGSeed(seed); + } + + /** + * Simple wrapper class which adds {@link AutoCloseable} to given type. + * + * @param <T> Type to wrap. + */ + public static class DataWrapper<T> implements AutoCloseable { + /** + * Value to wrap. + */ + T val; + + /** + * Wrap given value in {@link AutoCloseable}. + * + * @param val Value to wrap. + * @param <T> Type of value to wrap. + * @return Value wrapped as {@link AutoCloseable}. + */ + public static <T> DataWrapper<T> of(T val) { + return new DataWrapper<>(val); + } + + /** + * Construct instance of this class from given value. + * + * @param val Value to wrap. + */ + public DataWrapper(T val) { + this.val = val; + } + + /** + * Get wrapped value. + * + * @return Wrapped value. + */ + public T val() { + return val; + } + + /** {@inheritDoc} */ + @Override public void close() throws Exception { + if (val instanceof AutoCloseable) + ((AutoCloseable)val).close(); + } + } + + /** + * Return model which returns given constant. + * + * @param v Constant value. + * @param <T> Type of input. 
+ * @param <V> Type of output. + * @return Model which returns given constant. + */ + public static <T, V> Model<T, V> constantModel(V v) { + return t -> v; + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/mean/MeanAbsValueConvergenceCheckerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/mean/MeanAbsValueConvergenceCheckerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/mean/MeanAbsValueConvergenceCheckerTest.java index 0b42db8..c218a74 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/mean/MeanAbsValueConvergenceCheckerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/mean/MeanAbsValueConvergenceCheckerTest.java @@ -17,6 +17,7 @@ package org.apache.ignite.ml.composition.boosting.convergence.mean; +import org.apache.ignite.ml.TestUtils; import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceChecker; import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceCheckerTest; import org.apache.ignite.ml.dataset.impl.local.LocalDataset; @@ -25,6 +26,7 @@ import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData; import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapDataBuilder; import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder; import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.junit.Assert; import org.junit.Test; @@ -39,11 +41,14 @@ public class MeanAbsValueConvergenceCheckerTest extends ConvergenceCheckerTest { new MeanAbsValueConvergenceCheckerFactory(0.1), 
datasetBuilder); double error = checker.computeError(VectorUtils.of(1, 2), 4.0, notConvergedMdl); + LearningEnvironmentBuilder envBuilder = TestUtils.testEnvBuilder(); + Assert.assertEquals(1.9, error, 0.01); - Assert.assertFalse(checker.isConverged(datasetBuilder, notConvergedMdl)); - Assert.assertTrue(checker.isConverged(datasetBuilder, convergedMdl)); + Assert.assertFalse(checker.isConverged(envBuilder, datasetBuilder, notConvergedMdl)); + Assert.assertTrue(checker.isConverged(envBuilder, datasetBuilder, convergedMdl)); try(LocalDataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> dataset = datasetBuilder.build( + envBuilder, new EmptyContextBuilder<>(), new FeatureMatrixWithLabelsOnHeapDataBuilder<>(fExtr, lbExtr))) { double onDSError = checker.computeMeanErrorOnDataset(dataset, notConvergedMdl); @@ -62,6 +67,7 @@ public class MeanAbsValueConvergenceCheckerTest extends ConvergenceCheckerTest { new MeanAbsValueConvergenceCheckerFactory(0.1), datasetBuilder); try(LocalDataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> dataset = datasetBuilder.build( + TestUtils.testEnvBuilder(), new EmptyContextBuilder<>(), new FeatureMatrixWithLabelsOnHeapDataBuilder<>(fExtr, lbExtr))) { double onDSError = checker.computeMeanErrorOnDataset(dataset, notConvergedMdl); http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/median/MedianOfMedianConvergenceCheckerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/median/MedianOfMedianConvergenceCheckerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/median/MedianOfMedianConvergenceCheckerTest.java index d6880b4..0476a37 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/median/MedianOfMedianConvergenceCheckerTest.java +++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/median/MedianOfMedianConvergenceCheckerTest.java @@ -17,6 +17,7 @@ package org.apache.ignite.ml.composition.boosting.convergence.median; +import org.apache.ignite.ml.TestUtils; import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceChecker; import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceCheckerTest; import org.apache.ignite.ml.dataset.impl.local.LocalDataset; @@ -25,6 +26,7 @@ import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData; import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapDataBuilder; import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder; import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.junit.Assert; import org.junit.Test; @@ -42,10 +44,14 @@ public class MedianOfMedianConvergenceCheckerTest extends ConvergenceCheckerTest double error = checker.computeError(VectorUtils.of(1, 2), 4.0, notConvergedMdl); Assert.assertEquals(1.9, error, 0.01); - Assert.assertFalse(checker.isConverged(datasetBuilder, notConvergedMdl)); - Assert.assertTrue(checker.isConverged(datasetBuilder, convergedMdl)); + + LearningEnvironmentBuilder envBuilder = TestUtils.testEnvBuilder(); + + Assert.assertFalse(checker.isConverged(envBuilder, datasetBuilder, notConvergedMdl)); + Assert.assertTrue(checker.isConverged(envBuilder, datasetBuilder, convergedMdl)); try(LocalDataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> dataset = datasetBuilder.build( + envBuilder, new EmptyContextBuilder<>(), new FeatureMatrixWithLabelsOnHeapDataBuilder<>(fExtr, lbExtr))) { double onDSError = checker.computeMeanErrorOnDataset(dataset, notConvergedMdl); 
http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetBuilderTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetBuilderTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetBuilderTest.java index 1cf6dbf..815bd86 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetBuilderTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetBuilderTest.java @@ -26,6 +26,7 @@ import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; import org.apache.ignite.cluster.ClusterNode; import org.apache.ignite.configuration.CacheConfiguration; import org.apache.ignite.internal.util.IgniteUtils; +import org.apache.ignite.ml.TestUtils; import org.apache.ignite.ml.dataset.UpstreamEntry; import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; @@ -66,8 +67,9 @@ public class CacheBasedDatasetBuilderTest extends GridCommonAbstractTest { CacheBasedDatasetBuilder<Integer, String> builder = new CacheBasedDatasetBuilder<>(ignite, upstreamCache); CacheBasedDataset<Integer, String, Long, AutoCloseable> dataset = builder.build( - (upstream, upstreamSize) -> upstreamSize, - (upstream, upstreamSize, ctx) -> null + TestUtils.testEnvBuilder(), + (env, upstream, upstreamSize) -> upstreamSize, + (env, upstream, upstreamSize, ctx) -> null ); Affinity<Integer> upstreamAffinity = ignite.affinity(upstreamCache.getName()); @@ -105,14 +107,15 @@ public class CacheBasedDatasetBuilderTest extends GridCommonAbstractTest { ); CacheBasedDataset<Integer, Integer, Long, AutoCloseable> dataset = builder.build( - (upstream, upstreamSize) -> { + TestUtils.testEnvBuilder(), + (env, upstream, upstreamSize) -> { UpstreamEntry<Integer, Integer> entry = 
upstream.next(); assertEquals(Integer.valueOf(2), entry.getKey()); assertEquals(Integer.valueOf(2), entry.getValue()); assertFalse(upstream.hasNext()); return 0L; }, - (upstream, upstreamSize, ctx) -> { + (env, upstream, upstreamSize, ctx) -> { UpstreamEntry<Integer, Integer> entry = upstream.next(); assertEquals(Integer.valueOf(2), entry.getKey()); assertEquals(Integer.valueOf(2), entry.getValue()); http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetTest.java index a892530..7e31b07 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetTest.java @@ -38,6 +38,7 @@ import org.apache.ignite.internal.processors.cache.distributed.dht.topology.Grid import org.apache.ignite.internal.util.IgniteUtils; import org.apache.ignite.internal.util.typedef.G; import org.apache.ignite.lang.IgnitePredicate; +import org.apache.ignite.ml.TestUtils; import org.apache.ignite.ml.dataset.primitive.data.SimpleDatasetData; import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; @@ -83,8 +84,9 @@ public class CacheBasedDatasetTest extends GridCommonAbstractTest { CacheBasedDatasetBuilder<Integer, String> builder = new CacheBasedDatasetBuilder<>(ignite, upstreamCache); CacheBasedDataset<Integer, String, Long, SimpleDatasetData> dataset = builder.build( - (upstream, upstreamSize) -> upstreamSize, - (upstream, upstreamSize, ctx) -> new SimpleDatasetData(new double[0], 0) + TestUtils.testEnvBuilder(), + (env, upstream, upstreamSize) -> upstreamSize, + (env, upstream, upstreamSize, ctx) 
-> new SimpleDatasetData(new double[0], 0) ); assertEquals("Upstream cache name from dataset", @@ -138,8 +140,9 @@ public class CacheBasedDatasetTest extends GridCommonAbstractTest { CacheBasedDatasetBuilder<Integer, String> builder = new CacheBasedDatasetBuilder<>(ignite, upstreamCache); CacheBasedDataset<Integer, String, Long, SimpleDatasetData> dataset = builder.build( - (upstream, upstreamSize) -> upstreamSize, - (upstream, upstreamSize, ctx) -> new SimpleDatasetData(new double[0], 0) + TestUtils.testEnvBuilder(), + (env, upstream, upstreamSize) -> upstreamSize, + (env, upstream, upstreamSize, ctx) -> new SimpleDatasetData(new double[0], 0) ); assertTrue("Before computation all partitions should not be reserved", http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtilsTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtilsTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtilsTest.java index cee8f4f..202b6bc 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtilsTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtilsTest.java @@ -32,8 +32,9 @@ import org.apache.ignite.cache.affinity.AffinityFunctionContext; import org.apache.ignite.cluster.ClusterNode; import org.apache.ignite.configuration.CacheConfiguration; import org.apache.ignite.internal.util.IgniteUtils; +import org.apache.ignite.ml.TestUtils; import org.apache.ignite.ml.dataset.UpstreamEntry; -import org.apache.ignite.ml.dataset.UpstreamTransformerChain; +import org.apache.ignite.ml.dataset.UpstreamTransformerBuilder; import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; /** @@ -179,18 +180,18 @@ public class ComputeUtilsTest extends GridCommonAbstractTest { 
ignite, upstreamCacheName, (k, v) -> true, - UpstreamTransformerChain.empty(), + UpstreamTransformerBuilder.identity(), datasetCacheName, datasetId, - 0, - (upstream, upstreamSize, ctx) -> { + (env, upstream, upstreamSize, ctx) -> { cnt.incrementAndGet(); assertEquals(1, upstreamSize); UpstreamEntry<Integer, Integer> e = upstream.next(); return new TestPartitionData(e.getKey() + e.getValue()); - } + }, + TestUtils.testEnvBuilder().buildForWorker(part) ), 0 ); @@ -229,15 +230,16 @@ public class ComputeUtilsTest extends GridCommonAbstractTest { ignite, upstreamCacheName, (k, v) -> true, - UpstreamTransformerChain.empty(), + UpstreamTransformerBuilder.identity(), datasetCacheName, - (upstream, upstreamSize) -> { + (env, upstream, upstreamSize) -> { assertEquals(1, upstreamSize); UpstreamEntry<Integer, Integer> e = upstream.next(); return e.getKey() + e.getValue(); }, + TestUtils.testEnvBuilder(), 0 ); http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilderTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilderTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilderTest.java index 8dc9354..6088140 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilderTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilderTest.java @@ -21,6 +21,7 @@ import java.io.Serializable; import java.util.HashMap; import java.util.Map; import java.util.concurrent.atomic.AtomicLong; +import org.apache.ignite.ml.TestUtils; import org.apache.ignite.ml.dataset.PartitionContextBuilder; import org.apache.ignite.ml.dataset.PartitionDataBuilder; import org.junit.Test; @@ -47,7 +48,7 @@ public class LocalDatasetBuilderTest { AtomicLong cnt = new AtomicLong(); - 
dataset.compute((partData, partIdx) -> { + dataset.compute((partData, env) -> { cnt.incrementAndGet(); int[] arr = partData.data; @@ -55,7 +56,7 @@ public class LocalDatasetBuilderTest { assertEquals(10, arr.length); for (int i = 0; i < 10; i++) - assertEquals(partIdx * 10 + i, arr[i]); + assertEquals(env.partition() * 10 + i, arr[i]); }); assertEquals(10, cnt.intValue()); @@ -74,7 +75,7 @@ public class LocalDatasetBuilderTest { AtomicLong cnt = new AtomicLong(); - dataset.compute((partData, partIdx) -> { + dataset.compute((partData, env) -> { cnt.incrementAndGet(); int[] arr = partData.data; @@ -82,7 +83,7 @@ public class LocalDatasetBuilderTest { assertEquals(5, arr.length); for (int i = 0; i < 5; i++) - assertEquals((partIdx * 5 + i) * 2, arr[i]); + assertEquals((env.partition() * 5 + i) * 2, arr[i]); }); assertEquals(10, cnt.intValue()); @@ -91,10 +92,10 @@ public class LocalDatasetBuilderTest { /** */ private LocalDataset<Serializable, TestPartitionData> buildDataset( LocalDatasetBuilder<Integer, Integer> builder) { - PartitionContextBuilder<Integer, Integer, Serializable> partCtxBuilder = (upstream, upstreamSize) -> null; + PartitionContextBuilder<Integer, Integer, Serializable> partCtxBuilder = (env, upstream, upstreamSize) -> null; PartitionDataBuilder<Integer, Integer, Serializable, TestPartitionData> partDataBuilder - = (upstream, upstreamSize, ctx) -> { + = (env, upstream, upstreamSize, ctx) -> { int[] arr = new int[Math.toIntExact(upstreamSize)]; int ptr = 0; @@ -105,6 +106,7 @@ public class LocalDatasetBuilderTest { }; return builder.build( + TestUtils.testEnvBuilder(), partCtxBuilder.andThen(x -> null), partDataBuilder.andThen((x, y) -> x) ); http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/test/java/org/apache/ignite/ml/dataset/primitive/SimpleDatasetTest.java ---------------------------------------------------------------------- diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/dataset/primitive/SimpleDatasetTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/primitive/SimpleDatasetTest.java index eaa03d2..33c0677 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/dataset/primitive/SimpleDatasetTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/primitive/SimpleDatasetTest.java @@ -19,6 +19,7 @@ package org.apache.ignite.ml.dataset.primitive; import java.util.HashMap; import java.util.Map; +import org.apache.ignite.ml.TestUtils; import org.apache.ignite.ml.dataset.DatasetFactory; import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.junit.Test; @@ -43,6 +44,7 @@ public class SimpleDatasetTest { try (SimpleDataset<?> dataset = DatasetFactory.createSimpleDataset( dataPoints, 2, + TestUtils.testEnvBuilder(), (k, v) -> VectorUtils.of(v.getAge(), v.getSalary()) )) { assertArrayEquals("Mean values.", new double[] {37.75, 66000.0}, dataset.mean(), 0); http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/test/java/org/apache/ignite/ml/dataset/primitive/SimpleLabeledDatasetTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/dataset/primitive/SimpleLabeledDatasetTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/primitive/SimpleLabeledDatasetTest.java index f7b0f13..36e540b 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/dataset/primitive/SimpleLabeledDatasetTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/primitive/SimpleLabeledDatasetTest.java @@ -19,6 +19,7 @@ package org.apache.ignite.ml.dataset.primitive; import java.util.HashMap; import java.util.Map; +import org.apache.ignite.ml.TestUtils; import org.apache.ignite.ml.dataset.DatasetFactory; import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.junit.Test; @@ -47,14 +48,16 @@ public class 
SimpleLabeledDatasetTest { // Creates a local simple dataset containing features and providing standard dataset API. try (SimpleLabeledDataset<?> dataset = DatasetFactory.createSimpleLabeledDataset( dataPoints, + TestUtils.testEnvBuilder(), 2, (k, v) -> VectorUtils.of(v.getAge(), v.getSalary()), (k, v) -> new double[] {k, v.getAge(), v.getSalary()} )) { - assertNull(dataset.compute((data, partIdx) -> { - actualFeatures[partIdx] = data.getFeatures(); - actualLabels[partIdx] = data.getLabels(); - actualRows[partIdx] = data.getRows(); + assertNull(dataset.compute((data, env) -> { + int part = env.partition(); + actualFeatures[part] = data.getFeatures(); + actualLabels[part] = data.getLabels(); + actualRows[part] = data.getRows(); return null; }, (k, v) -> null)); } http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentBuilderTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentBuilderTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentBuilderTest.java index 56f262b..7769092 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentBuilderTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentBuilderTest.java @@ -39,7 +39,7 @@ public class LearningEnvironmentBuilderTest { /** */ @Test public void basic() { - LearningEnvironment env = LearningEnvironment.DEFAULT; + LearningEnvironment env = LearningEnvironment.DEFAULT_TRAINER_ENV; assertNotNull("Strategy", env.parallelismStrategy()); assertNotNull("Logger", env.logger()); @@ -49,42 +49,44 @@ public class LearningEnvironmentBuilderTest { /** */ @Test public void withParallelismStrategy() { - assertTrue(LearningEnvironment.builder().withParallelismStrategy(NoParallelismStrategy.INSTANCE).build() + 
assertTrue(LearningEnvironmentBuilder.defaultBuilder().withParallelismStrategyDependency(part -> NoParallelismStrategy.INSTANCE) + .buildForTrainer() .parallelismStrategy() instanceof NoParallelismStrategy); - assertTrue(LearningEnvironment.builder().withParallelismStrategy(new DefaultParallelismStrategy()).build() + assertTrue(LearningEnvironmentBuilder.defaultBuilder().withParallelismStrategyDependency(part -> new DefaultParallelismStrategy()) + .buildForTrainer() .parallelismStrategy() instanceof DefaultParallelismStrategy); } /** */ @Test public void withParallelismStrategyType() { - assertTrue(LearningEnvironment.builder().withParallelismStrategy(NO_PARALLELISM).build() + assertTrue(LearningEnvironmentBuilder.defaultBuilder().withParallelismStrategyType(NO_PARALLELISM).buildForTrainer() .parallelismStrategy() instanceof NoParallelismStrategy); - assertTrue(LearningEnvironment.builder().withParallelismStrategy(ON_DEFAULT_POOL).build() + assertTrue(LearningEnvironmentBuilder.defaultBuilder().withParallelismStrategyType(ON_DEFAULT_POOL).buildForTrainer() .parallelismStrategy() instanceof DefaultParallelismStrategy); } /** */ @Test public void withLoggingFactory() { - assertTrue(LearningEnvironment.builder().withLoggingFactory(ConsoleLogger.factory(MLLogger.VerboseLevel.HIGH)) - .build().logger() instanceof ConsoleLogger); + assertTrue(LearningEnvironmentBuilder.defaultBuilder().withLoggingFactoryDependency(part -> ConsoleLogger.factory(MLLogger.VerboseLevel.HIGH)) + .buildForTrainer().logger() instanceof ConsoleLogger); - assertTrue(LearningEnvironment.builder().withLoggingFactory(ConsoleLogger.factory(MLLogger.VerboseLevel.HIGH)) - .build().logger(this.getClass()) instanceof ConsoleLogger); + assertTrue(LearningEnvironmentBuilder.defaultBuilder().withLoggingFactoryDependency(part -> ConsoleLogger.factory(MLLogger.VerboseLevel.HIGH)) + .buildForTrainer().logger(this.getClass()) instanceof ConsoleLogger); - 
assertTrue(LearningEnvironment.builder().withLoggingFactory(NoOpLogger.factory()) - .build().logger() instanceof NoOpLogger); + assertTrue(LearningEnvironmentBuilder.defaultBuilder().withLoggingFactoryDependency(part -> NoOpLogger.factory()) + .buildForTrainer().logger() instanceof NoOpLogger); - assertTrue(LearningEnvironment.builder().withLoggingFactory(NoOpLogger.factory()) - .build().logger(this.getClass()) instanceof NoOpLogger); + assertTrue(LearningEnvironmentBuilder.defaultBuilder().withLoggingFactoryDependency(part -> NoOpLogger.factory()) + .buildForTrainer().logger(this.getClass()) instanceof NoOpLogger); - assertTrue(LearningEnvironment.builder().withLoggingFactory(CustomMLLogger.factory(new NullLogger())) - .build().logger() instanceof CustomMLLogger); + assertTrue(LearningEnvironmentBuilder.defaultBuilder().withLoggingFactoryDependency(part -> CustomMLLogger.factory(new NullLogger())) + .buildForTrainer().logger() instanceof CustomMLLogger); - assertTrue(LearningEnvironment.builder().withLoggingFactory(CustomMLLogger.factory(new NullLogger())) - .build().logger(this.getClass()) instanceof CustomMLLogger); + assertTrue(LearningEnvironmentBuilder.defaultBuilder().withLoggingFactoryDependency(part -> CustomMLLogger.factory(new NullLogger())) + .buildForTrainer().logger(this.getClass()) instanceof CustomMLLogger); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/ff6b8eed/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java index 73192f0..4b44196 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java @@ -17,17 +17,31 @@ 
package org.apache.ignite.ml.environment; +import java.util.Map; +import java.util.Random; import java.util.stream.Collectors; import java.util.stream.IntStream; +import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.TestUtils; +import org.apache.ignite.ml.dataset.Dataset; +import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.dataset.PartitionDataBuilder; import org.apache.ignite.ml.dataset.feature.FeatureMeta; +import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder; +import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; import org.apache.ignite.ml.environment.logging.ConsoleLogger; import org.apache.ignite.ml.environment.logging.MLLogger; import org.apache.ignite.ml.environment.parallelism.DefaultParallelismStrategy; import org.apache.ignite.ml.environment.parallelism.ParallelismStrategy; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; +import org.apache.ignite.ml.trainers.DatasetTrainer; import org.apache.ignite.ml.tree.randomforest.RandomForestRegressionTrainer; import org.apache.ignite.ml.tree.randomforest.data.FeaturesCountSelectionStrategies; import org.junit.Test; +import static org.apache.ignite.ml.TestUtils.constantModel; import static org.junit.Assert.assertEquals; /** @@ -48,13 +62,115 @@ public class LearningEnvironmentTest { .withSubSampleSize(0.3) .withSeed(0); - LearningEnvironment environment = LearningEnvironment.builder() - .withParallelismStrategy(ParallelismStrategy.Type.ON_DEFAULT_POOL) - .withLoggingFactory(ConsoleLogger.factory(MLLogger.VerboseLevel.LOW)) - .build(); - trainer.setEnvironment(environment); - assertEquals(DefaultParallelismStrategy.class, environment.parallelismStrategy().getClass()); - assertEquals(ConsoleLogger.class, environment.logger().getClass()); + LearningEnvironmentBuilder envBuilder = 
LearningEnvironmentBuilder.defaultBuilder() + .withParallelismStrategyType(ParallelismStrategy.Type.ON_DEFAULT_POOL) + .withLoggingFactoryDependency(part -> ConsoleLogger.factory(MLLogger.VerboseLevel.LOW)); + + trainer.withEnvironmentBuilder(envBuilder); + + assertEquals(DefaultParallelismStrategy.class, trainer.learningEnvironment().parallelismStrategy().getClass()); + assertEquals(ConsoleLogger.class, trainer.learningEnvironment().logger().getClass()); + } + + /** + * Test random number generator provided by {@link LearningEnvironment}. + * We test that: + * 1. Correct random generator is returned for each partition. + * 2. Its state is saved between compute calls (for this we do several iterations of compute). + */ + @Test + public void testRandomNumbersGenerator() { + // We make such builders that provide as functions returning partition index * iteration as random number generator nextInt + LearningEnvironmentBuilder envBuilder = TestUtils.testEnvBuilder().withRandomDependency(MockRandom::new); + int partitions = 10; + int iterations = 2; + + DatasetTrainer<Model<Object, Vector>, Void> trainer = new DatasetTrainer<Model<Object, Vector>, Void>() { + /** {@inheritDoc} */ + @Override public <K, V> Model<Object, Vector> fit(DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Void> lbExtractor) { + Dataset<EmptyContext, TestUtils.DataWrapper<Integer>> ds = datasetBuilder.build(envBuilder, + new EmptyContextBuilder<>(), + (PartitionDataBuilder<K, V, EmptyContext, TestUtils.DataWrapper<Integer>>)(env, upstreamData, upstreamDataSize, ctx) -> + TestUtils.DataWrapper.of(env.partition())); + + Vector v = null; + for (int iter = 0; iter < iterations; iter++) { + v = ds.compute((dw, env) -> VectorUtils.fill(-1, partitions).set(env.partition(), env.randomNumbersGenerator().nextInt()), + (v1, v2) -> zipOverridingEmpty(v1, v2, -1)); + } + return constantModel(v); + } + + /** {@inheritDoc} */ + @Override protected 
boolean checkState(Model<Object, Vector> mdl) { + return false; + } + + /** {@inheritDoc} */ + @Override protected <K, V> Model<Object, Vector> updateModel(Model<Object, Vector> mdl, + DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Void> lbExtractor) { + return null; + } + }; + trainer.withEnvironmentBuilder(envBuilder); + Model<Object, Vector> mdl = trainer.fit(getCacheMock(partitions), partitions, null, null); + + Vector exp = VectorUtils.zeroes(partitions); + for (int i = 0; i < partitions; i++) + exp.set(i, i * iterations); + + + Vector res = mdl.apply(null); + assertEquals(exp, res); + } + + /** + * For given two vectors {@code v1, v2} produce vector {@code v} where each component of {@code v} + * is produced from corresponding components {@code c1, c2} of {@code v1, v2} respectively, in the following way: + * {@code c = c1 != empty ? c1 : c2}. For example, zipping [2, -1, -1], [-1, 3, -1] will result in [2, 3, -1]. + * + * @param v1 First vector. + * @param v2 Second vector. + * @param empty Value treated as empty. + * @return Result of zipping as described above. + */ + private static Vector zipOverridingEmpty(Vector v1, Vector v2, double empty) { + return v1 != null ? (v2 != null ? VectorUtils.zipWith(v1, v2, (d1, d2) -> d1 != empty ? d1 : d2) : v1) : v2; + } + + /** Get cache mock */ + private Map<Integer, Integer> getCacheMock(int partsCnt) { + return IntStream.range(0, partsCnt).boxed().collect(Collectors.toMap(x -> x, x -> x)); + } + + /** Mock random numbers generator. */ + private static class MockRandom extends Random { + /** Serial version uuid. */ + private static final long serialVersionUID = -7738558243461112988L; + + /** Start value. */ + private int startVal; + + /** Iteration. */ + private int iter; + + /** + * Constructs instance of this class with a specified start value. + * + * @param startVal Start value.
+ */ + MockRandom(int startVal) { + this.startVal = startVal; + iter = 0; + } + + /** {@inheritDoc} */ + @Override public int nextInt() { + iter++; + return startVal * iter; + } } }