IGNITE-9514 :[ML] Reduce time for the updating models on many partitions this closes #4788
Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/eff5751e Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/eff5751e Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/eff5751e Branch: refs/heads/ignite-gg-14206 Commit: eff5751e49be6b7fa5b07e2dee2faaa9e2574417 Parents: 2f72fe7 Author: zaleslaw <zaleslaw....@gmail.com> Authored: Tue Sep 25 15:51:54 2018 +0300 Committer: Yury Babak <yba...@gridgain.com> Committed: Tue Sep 25 15:51:54 2018 +0300 ---------------------------------------------------------------------- .../RandomForestClassificationExample.java | 2 +- .../RandomForestRegressionExample.java | 2 +- .../boosting/GDBBinaryClassifierTrainer.java | 10 ++--- .../boosting/GDBLearningStrategy.java | 18 ++++---- .../impl/bootstrapping/BootstrappedVector.java | 6 +-- .../ml/environment/LearningEnvironment.java | 4 +- .../ml/environment/logging/ConsoleLogger.java | 6 +-- .../ml/environment/logging/CustomMLLogger.java | 30 ++++++------- .../ignite/ml/environment/logging/MLLogger.java | 9 +++- .../parallelism/NoParallelismStrategy.java | 14 +++--- .../classification/KNNClassificationModel.java | 13 +++--- .../org/apache/ignite/ml/nn/MLPTrainer.java | 8 ++-- .../maxabsscaling/MaxAbsScalerTrainer.java | 10 ++--- ...VMLinearMultiClassClassificationTrainer.java | 9 ++-- .../ignite/ml/trainers/DatasetTrainer.java | 4 +- .../org/apache/ignite/ml/tree/DecisionTree.java | 14 +++--- .../tree/DecisionTreeClassificationTrainer.java | 6 +-- .../boosting/GDBOnTreesLearningStrategy.java | 19 +++++---- .../ignite/ml/tree/data/DecisionTreeData.java | 16 +++---- .../ml/tree/data/DecisionTreeDataBuilder.java | 6 +-- .../ignite/ml/tree/data/TreeDataIndex.java | 32 +++++++------- .../impurity/ImpurityMeasureCalculator.java | 18 ++++---- .../gini/GiniImpurityMeasureCalculator.java | 36 ++++++++-------- .../mse/MSEImpurityMeasureCalculator.java | 28 ++++++------ .../tree/randomforest/RandomForestTrainer.java | 10 ++--- .../ml/tree/randomforest/data/NodeSplit.java | 14 +++--- .../ml/tree/randomforest/data/TreeNode.java | 26 +++++------ .../impurity/ImpurityHistogramsComputer.java | 4 +- .../data/statistics/LeafValuesComputer.java | 2 +- .../org/apache/ignite/ml/util/ModelTrace.java | 6 +-- .../ignite/ml/clustering/KMeansTrainerTest.java | 11 ++--- .../apache/ignite/ml/common/TrainerTest.java | 4 +- .../ml/composition/boosting/GDBTrainerTest.java | 3 +- .../MeanValuePredictionsAggregatorTest.java | 2 + .../OnMajorityPredictionsAggregatorTest.java | 1 + .../ml/dataset/feature/ObjectHistogramTest.java | 39 ++++++++--------- .../ml/environment/LearningEnvironmentTest.java | 2 +- .../ignite/ml/knn/ANNClassificationTest.java | 18 +++++--- .../apache/ignite/ml/knn/KNNRegressionTest.java | 26 +---------- .../apache/ignite/ml/math/VectorUtilsTest.java | 9 ++-- .../ml/math/isolve/lsqr/LSQROnHeapTest.java | 24 +---------- .../ignite/ml/pipeline/PipelineMdlTest.java | 5 +++ .../binarization/BinarizationTrainerTest.java | 25 +---------- .../encoding/EncoderTrainerTest.java | 25 +---------- .../imputing/ImputerTrainerTest.java | 25 +---------- .../MaxAbsScalerPreprocessorTest.java | 4 +- .../maxabsscaling/MaxAbsScalerTrainerTest.java | 25 +---------- .../minmaxscaling/MinMaxScalerTrainerTest.java | 25 +---------- .../normalization/NormalizationTrainerTest.java | 25 +---------- .../linear/LinearRegressionLSQRTrainerTest.java | 14 +++--- .../linear/LinearRegressionSGDTrainerTest.java | 14 +++--- .../logistic/LogRegMultiClassTrainerTest.java | 10 ++--- .../LogisticRegressionSGDTrainerTest.java | 14 +++--- .../tree/DecisionTreeRegressionTrainerTest.java | 4 +- .../ml/tree/data/DecisionTreeDataTest.java | 6 +-- .../ignite/ml/tree/data/TreeDataIndexTest.java | 36 ++++++++-------- .../gini/GiniImpurityMeasureCalculatorTest.java | 12 +++--- .../mse/MSEImpurityMeasureCalculatorTest.java | 6 +-- .../RandomForestClassifierTrainerTest.java | 44 ++++--------------- .../RandomForestRegressionTrainerTest.java | 41 ++++-------------- .../ml/tree/randomforest/RandomForestTest.java | 4 +- .../data/impurity/GiniFeatureHistogramTest.java | 19 +++++---- .../data/impurity/ImpurityHistogramTest.java | 45 ++++++++++++++------ .../data/impurity/MSEHistogramTest.java | 15 ++++--- ...ormalDistributionStatisticsComputerTest.java | 45 ++++++++++---------- 65 files changed, 422 insertions(+), 587 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestClassificationExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestClassificationExample.java index 6194153..4693744 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestClassificationExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestClassificationExample.java @@ -70,7 +70,7 @@ public class RandomForestClassificationExample { RandomForestClassifierTrainer classifier = new RandomForestClassifierTrainer( IntStream.range(0, data[0].length - 1).mapToObj( x -> new FeatureMeta("", idx.getAndIncrement(), false)).collect(Collectors.toList()) - ).withCountOfTrees(101) + ).withAmountOfTrees(101) .withFeaturesCountSelectionStrgy(FeaturesCountSelectionStrategies.ONE_THIRD) .withMaxDepth(4) .withMinImpurityDelta(0.) http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestRegressionExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestRegressionExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestRegressionExample.java index 5f010f2..ee0c1c2 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestRegressionExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestRegressionExample.java @@ -74,7 +74,7 @@ public class RandomForestRegressionExample { RandomForestRegressionTrainer trainer = new RandomForestRegressionTrainer( IntStream.range(0, data[0].length - 1).mapToObj( x -> new FeatureMeta("", idx.getAndIncrement(), false)).collect(Collectors.toList()) - ).withCountOfTrees(101) + ).withAmountOfTrees(101) .withFeaturesCountSelectionStrgy(FeaturesCountSelectionStrategies.ONE_THIRD) .withMaxDepth(4) .withMinImpurityDelta(0.) http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBBinaryClassifierTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBBinaryClassifierTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBBinaryClassifierTrainer.java index f6ddfed..8682a46 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBBinaryClassifierTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBBinaryClassifierTrainer.java @@ -82,13 +82,13 @@ public abstract class GDBBinaryClassifierTrainer extends GDBTrainer { ); if (uniqLabels != null && uniqLabels.size() == 2) { - ArrayList<Double> lblsArray = new ArrayList<>(uniqLabels); - externalFirstCls = lblsArray.get(0); - externalSecondCls = lblsArray.get(1); + ArrayList<Double> lblsArr = new ArrayList<>(uniqLabels); + externalFirstCls = lblsArr.get(0); + externalSecondCls = lblsArr.get(1); return true; - } else { - return false; } + else + return false; } /** {@inheritDoc} */ http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBLearningStrategy.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBLearningStrategy.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBLearningStrategy.java index 737495e..e689b91 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBLearningStrategy.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBLearningStrategy.java @@ -57,7 +57,7 @@ public class GDBLearningStrategy { protected IgniteSupplier<DatasetTrainer<? extends Model<Vector, Double>, Double>> baseMdlTrainerBuilder; /** Mean label value. */ - protected double meanLabelValue; + protected double meanLbVal; /** Sample size. */ protected long sampleSize; @@ -111,7 +111,7 @@ public class GDBLearningStrategy { for (int i = 0; i < cntOfIterations; i++) { double[] weights = Arrays.copyOf(compositionWeights, models.size()); - WeightedPredictionsAggregator aggregator = new WeightedPredictionsAggregator(weights, meanLabelValue); + WeightedPredictionsAggregator aggregator = new WeightedPredictionsAggregator(weights, meanLbVal); ModelsComposition currComposition = new ModelsComposition(models, aggregator); if (convCheck.isConverged(datasetBuilder, currComposition)) break; @@ -142,13 +142,13 @@ public class GDBLearningStrategy { if(mdlToUpdate != null) { models.addAll(mdlToUpdate.getModels()); WeightedPredictionsAggregator aggregator = (WeightedPredictionsAggregator) mdlToUpdate.getPredictionsAggregator(); - meanLabelValue = aggregator.getBias(); + meanLbVal = aggregator.getBias(); compositionWeights = new double[models.size() + cntOfIterations]; for(int i = 0; i < models.size(); i++) compositionWeights[i] = aggregator.getWeights()[i]; - } else { - compositionWeights = new double[cntOfIterations]; } + else + compositionWeights = new double[cntOfIterations]; Arrays.fill(compositionWeights, models.size(), compositionWeights.length, defaultGradStepSize); return models; @@ -208,10 +208,10 @@ public class GDBLearningStrategy { /** * Sets mean label value. * - * @param meanLabelValue Mean label value. + * @param meanLbVal Mean label value. */ - public GDBLearningStrategy withMeanLabelValue(double meanLabelValue) { - this.meanLabelValue = meanLabelValue; + public GDBLearningStrategy withMeanLabelValue(double meanLbVal) { + this.meanLbVal = meanLbVal; return this; } @@ -262,6 +262,6 @@ public class GDBLearningStrategy { /** */ public double getMeanValue() { - return meanLabelValue; + return meanLbVal; } } http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/bootstrapping/BootstrappedVector.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/bootstrapping/BootstrappedVector.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/bootstrapping/BootstrappedVector.java index aedd0fd..573b256 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/bootstrapping/BootstrappedVector.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/bootstrapping/BootstrappedVector.java @@ -68,9 +68,9 @@ public class BootstrappedVector extends LabeledVector<Vector, Double> { /** {@inheritDoc} */ @Override public int hashCode() { - int result = super.hashCode(); - result = 31 * result + Arrays.hashCode(counters); - return result; + int res = super.hashCode(); + res = 31 * res + Arrays.hashCode(counters); + return res; } /** {@inheritDoc} */ http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/main/java/org/apache/ignite/ml/environment/LearningEnvironment.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/environment/LearningEnvironment.java b/modules/ml/src/main/java/org/apache/ignite/ml/environment/LearningEnvironment.java index 2b94a2f..f5fb693 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/environment/LearningEnvironment.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/environment/LearningEnvironment.java @@ -41,9 +41,9 @@ public interface LearningEnvironment { /** * Returns an instance of logger for specific class. * - * @param forClass Logging class context. + * @param forCls Logging class context. */ - public <T> MLLogger logger(Class<T> forClass); + public <T> MLLogger logger(Class<T> forCls); /** * Creates an instance of LearningEnvironmentBuilder. http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/main/java/org/apache/ignite/ml/environment/logging/ConsoleLogger.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/environment/logging/ConsoleLogger.java b/modules/ml/src/main/java/org/apache/ignite/ml/environment/logging/ConsoleLogger.java index 7efa29c..e064fc3 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/environment/logging/ConsoleLogger.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/environment/logging/ConsoleLogger.java @@ -28,7 +28,7 @@ public class ConsoleLogger implements MLLogger { /** Maximum Verbose level. */ private final VerboseLevel maxVerboseLevel; /** Class name. */ - private final String className; + private final String clsName; /** * Creates an instance of ConsoleLogger. @@ -37,7 +37,7 @@ public class ConsoleLogger implements MLLogger { * @param clsName Class name. */ private ConsoleLogger(VerboseLevel maxVerboseLevel, String clsName) { - this.className = clsName; + this.clsName = clsName; this.maxVerboseLevel = maxVerboseLevel; } @@ -75,7 +75,7 @@ public class ConsoleLogger implements MLLogger { */ private void print(VerboseLevel verboseLevel, String line) { if (this.maxVerboseLevel.compareTo(verboseLevel) >= 0) - System.out.println(String.format("%s [%s] %s", className, verboseLevel.name(), line)); + System.out.println(String.format("%s [%s] %s", clsName, verboseLevel.name(), line)); } /** http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/main/java/org/apache/ignite/ml/environment/logging/CustomMLLogger.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/environment/logging/CustomMLLogger.java b/modules/ml/src/main/java/org/apache/ignite/ml/environment/logging/CustomMLLogger.java index 65bc4cb..90aed14 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/environment/logging/CustomMLLogger.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/environment/logging/CustomMLLogger.java @@ -27,29 +27,29 @@ import org.apache.ignite.ml.math.primitives.vector.Vector; */ public class CustomMLLogger implements MLLogger { /** Ignite logger instance. */ - private final IgniteLogger logger; + private final IgniteLogger log; /** * Creates an instance of CustomMLLogger. * - * @param logger Basic Logger. + * @param log Basic Logger. */ - private CustomMLLogger(IgniteLogger logger) { - this.logger = logger; + private CustomMLLogger(IgniteLogger log) { + this.log = log; } /** * Returns factory for OnIgniteLogger instantiating. * - * @param rootLogger Root logger. + * @param rootLog Root logger. */ - public static Factory factory(IgniteLogger rootLogger) { - return new Factory(rootLogger); + public static Factory factory(IgniteLogger rootLog) { + return new Factory(rootLog); } /** {@inheritDoc} */ @Override public Vector log(Vector vector) { - Tracer.showAscii(vector, logger); + Tracer.showAscii(vector, log); return vector; } @@ -73,10 +73,10 @@ public class CustomMLLogger implements MLLogger { private void log(VerboseLevel verboseLevel, String line) { switch (verboseLevel) { case LOW: - logger.info(line); + log.info(line); break; case HIGH: - logger.debug(line); + log.debug(line); break; } } @@ -86,20 +86,20 @@ public class CustomMLLogger implements MLLogger { */ private static class Factory implements MLLogger.Factory { /** Root logger. */ - private IgniteLogger rootLogger; + private IgniteLogger rootLog; /** * Creates an instance of factory. * - * @param rootLogger Root logger. + * @param rootLog Root logger. */ - public Factory(IgniteLogger rootLogger) { - this.rootLogger = rootLogger; + public Factory(IgniteLogger rootLog) { + this.rootLog = rootLog; } /** {@inheritDoc} */ @Override public <T> MLLogger create(Class<T> targetCls) { - return new CustomMLLogger(rootLogger.getLogger(targetCls)); + return new CustomMLLogger(rootLog.getLogger(targetCls)); } } } http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/main/java/org/apache/ignite/ml/environment/logging/MLLogger.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/environment/logging/MLLogger.java b/modules/ml/src/main/java/org/apache/ignite/ml/environment/logging/MLLogger.java index 872b947..b2b4739 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/environment/logging/MLLogger.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/environment/logging/MLLogger.java @@ -28,7 +28,14 @@ public interface MLLogger { * Logging verbose level. */ enum VerboseLevel { - OFF, LOW, HIGH + /** Disabled. */ + OFF, + + /** Low. */ + LOW, + + /** High. */ + HIGH } /** http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/main/java/org/apache/ignite/ml/environment/parallelism/NoParallelismStrategy.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/environment/parallelism/NoParallelismStrategy.java b/modules/ml/src/main/java/org/apache/ignite/ml/environment/parallelism/NoParallelismStrategy.java index 5f605a7..759e06a 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/environment/parallelism/NoParallelismStrategy.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/environment/parallelism/NoParallelismStrategy.java @@ -46,15 +46,17 @@ public class NoParallelismStrategy implements ParallelismStrategy { * @param <T> Type of result. */ public static class Stub<T> implements Promise<T> { - private T result; + + /** Result. */ + private T res; /** * Create an instance of Stub * - * @param result Execution result. + * @param res Execution result. */ - public Stub(T result) { - this.result = result; + public Stub(T res) { + this.res = res; } /** {@inheritDoc} */ @@ -74,14 +76,14 @@ public class NoParallelismStrategy implements ParallelismStrategy { /** {@inheritDoc} */ @Override public T get() throws InterruptedException, ExecutionException { - return result; + return res; } /** {@inheritDoc} */ @Override public T get(long timeout, @NotNull TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { - return result; + return res; } } } http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java index 0d03ee5..3de73bd 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java @@ -63,9 +63,9 @@ public class KNNClassificationModel extends NNClassificationModel implements Exp List<LabeledVector> neighbors = findKNearestNeighbors(v); return classify(neighbors, v, stgy); - } else { - throw new IllegalStateException("The train kNN dataset is null"); } + else + throw new IllegalStateException("The train kNN dataset is null"); } /** */ @@ -91,6 +91,7 @@ public class KNNClassificationModel extends NNClassificationModel implements Exp return Arrays.asList(getKClosestVectors(neighborsToFilter, getDistances(v, neighborsToFilter))); } + /** */ private List<LabeledVector> findKNearestNeighborsInDataset(Vector v, Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset) { List<LabeledVector> neighborsFromPartitions = dataset.compute(data -> { @@ -137,10 +138,10 @@ public class KNNClassificationModel extends NNClassificationModel implements Exp /** * Copy parameters from other model and save all datasets from it. * - * @param model Model. + * @param mdl Model. */ - public void copyStateFrom(KNNClassificationModel model) { - this.copyParametersFrom(model); - datasets.addAll(model.datasets); + public void copyStateFrom(KNNClassificationModel mdl) { + this.copyParametersFrom(mdl); + datasets.addAll(mdl.datasets); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java index 1cac909..c75c5bb 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java @@ -115,7 +115,7 @@ public class MLPTrainer<P extends Serializable> extends MultiLabelDatasetTrainer } /** {@inheritDoc} */ - @Override protected <K, V> MultilayerPerceptron updateModel(MultilayerPerceptron lastLearnedModel, + @Override protected <K, V> MultilayerPerceptron updateModel(MultilayerPerceptron lastLearnedMdl, DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, double[]> lbExtractor) { @@ -128,8 +128,8 @@ public class MLPTrainer<P extends Serializable> extends MultiLabelDatasetTrainer new SimpleLabeledDatasetDataBuilder<>(featureExtractor, lbExtractor) )) { MultilayerPerceptron mdl; - if (lastLearnedModel != null) - mdl = lastLearnedModel; + if (lastLearnedMdl != null) + mdl = lastLearnedMdl; else { MLPArchitecture arch = archSupplier.apply(dataset); mdl = new MultilayerPerceptron(arch, new RandomInitializer(seed)); @@ -196,7 +196,7 @@ public class MLPTrainer<P extends Serializable> extends MultiLabelDatasetTrainer ); if (totUp == null) - return getLastTrainedModelOrThrowEmptyDatasetException(lastLearnedModel); + return getLastTrainedModelOrThrowEmptyDatasetException(lastLearnedMdl); P update = updatesStgy.allUpdatesReducer().apply(totUp); mdl = updater.update(mdl, update); http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/maxabsscaling/MaxAbsScalerTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/maxabsscaling/MaxAbsScalerTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/maxabsscaling/MaxAbsScalerTrainer.java index d3e5734..c8b1dca 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/maxabsscaling/MaxAbsScalerTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/maxabsscaling/MaxAbsScalerTrainer.java @@ -69,12 +69,12 @@ public class MaxAbsScalerTrainer<K, V> implements PreprocessingTrainer<K, V, Vec if (b == null) return a; - double[] result = new double[a.length]; + double[] res = new double[a.length]; - for (int i = 0; i < result.length; i++) { - result[i] = Math.max(Math.abs(a[i]), Math.abs(b[i])); - } - return result; + for (int i = 0; i < res.length; i++) + res[i] = Math.max(Math.abs(a[i]), Math.abs(b[i])); + + return res; }); return new MaxAbsScalerPreprocessor<>(maxAbs, basePreprocessor); } http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java index 7cbb1dc..ec60034 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java @@ -96,12 +96,13 @@ public class SVMLinearMultiClassClassificationTrainer return 0.0; }; - SVMLinearBinaryClassificationModel model; + SVMLinearBinaryClassificationModel updatedMdl; + if (mdl == null) - model = learnNewModel(trainer, datasetBuilder, featureExtractor, lbTransformer); + updatedMdl = learnNewModel(trainer, datasetBuilder, featureExtractor, lbTransformer); else - model = updateModel(mdl, clsLb, trainer, datasetBuilder, featureExtractor, lbTransformer); - multiClsMdl.add(clsLb, model); + updatedMdl = updateModel(mdl, clsLb, trainer, datasetBuilder, featureExtractor, lbTransformer); + multiClsMdl.add(clsLb, updatedMdl); }); return multiClsMdl; http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java index 490c53d..5c3913e 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java @@ -70,9 +70,9 @@ public abstract class DatasetTrainer<M extends Model, L> { IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) { if(mdl != null) { - if(checkState(mdl)) { + if (checkState(mdl)) return updateModel(mdl, datasetBuilder, featureExtractor, lbExtractor); - } else { + else { environment.logger(getClass()).log( MLLogger.VerboseLevel.HIGH, "Model cannot be updated because of initial state of " + http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java index 45774cb..b40ca93 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java @@ -164,7 +164,7 @@ public abstract class DecisionTree<T extends ImpurityMeasure<T>> extends Dataset private StepFunction<T>[] calculateImpurityForAllColumns(Dataset<EmptyContext, DecisionTreeData> dataset, TreeFilter filter, ImpurityMeasureCalculator<T> impurityCalc, int depth) { - StepFunction<T>[] result = dataset.compute( + return dataset.compute( part -> { if (compressor != null) return compressor.compress(impurityCalc.calculate(part, filter, depth)); @@ -172,8 +172,6 @@ public abstract class DecisionTree<T extends ImpurityMeasure<T>> extends Dataset return impurityCalc.calculate(part, filter, depth); }, this::reduce ); - - return result; } /** @@ -314,16 +312,16 @@ public abstract class DecisionTree<T extends ImpurityMeasure<T>> extends Dataset .append(String.format("%.4f", leaf.getVal())); } else if (node instanceof DecisionTreeConditionalNode) { - DecisionTreeConditionalNode condition = (DecisionTreeConditionalNode)node; + DecisionTreeConditionalNode cond = (DecisionTreeConditionalNode)node; String prefix = depth == 0 ? "" : (isThen ? "then " : "else "); builder.append(String.format("%sif (x", prefix)) - .append(condition.getCol()) + .append(cond.getCol()) .append(" > ") - .append(String.format("%.4f", condition.getThreshold())) + .append(String.format("%.4f", cond.getThreshold())) .append(pretty ? ")\n" : ") "); - printTree(condition.getThenNode(), depth + 1, builder, pretty, true); + printTree(cond.getThenNode(), depth + 1, builder, pretty, true); builder.append(pretty ? "\n" : " "); - printTree(condition.getElseNode(), depth + 1, builder, pretty, false); + printTree(cond.getElseNode(), depth + 1, builder, pretty, false); } else throw new IllegalArgumentException(); http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java index 91ec8e1..58552f4 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java @@ -87,11 +87,11 @@ public class DecisionTreeClassificationTrainer extends DecisionTree<GiniImpurity /** * Sets useIndex parameter and returns trainer instance. * - * @param useIndex Use index. + * @param useIdx Use index. * @return Decision tree trainer. */ - public DecisionTreeClassificationTrainer withUseIndex(boolean useIndex) { - this.usingIdx = useIndex; + public DecisionTreeClassificationTrainer withUseIndex(boolean useIdx) { + this.usingIdx = useIdx; return this; } http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBOnTreesLearningStrategy.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBOnTreesLearningStrategy.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBOnTreesLearningStrategy.java index 6ebbda1..caac168 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBOnTreesLearningStrategy.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBOnTreesLearningStrategy.java @@ -43,15 +43,16 @@ import org.apache.ignite.ml.tree.data.DecisionTreeDataBuilder; * several learning iterations. */ public class GDBOnTreesLearningStrategy extends GDBLearningStrategy { - private boolean useIndex; + /** Use index. */ + private boolean useIdx; /** * Create an instance of learning strategy. * - * @param useIndex Use index. + * @param useIdx Use index. */ - public GDBOnTreesLearningStrategy(boolean useIndex) { - this.useIndex = useIndex; + public GDBOnTreesLearningStrategy(boolean useIdx) { + this.useIdx = useIdx; } /** {@inheritDoc} */ @@ -70,23 +71,23 @@ public class GDBOnTreesLearningStrategy extends GDBLearningStrategy { try (Dataset<EmptyContext, DecisionTreeData> dataset = datasetBuilder.build( new EmptyContextBuilder<>(), - new DecisionTreeDataBuilder<>(featureExtractor, lbExtractor, useIndex) + new DecisionTreeDataBuilder<>(featureExtractor, lbExtractor, useIdx) )) { for (int i = 0; i < cntOfIterations; i++) { double[] weights = Arrays.copyOf(compositionWeights, models.size()); - WeightedPredictionsAggregator aggregator = new WeightedPredictionsAggregator(weights, meanLabelValue); + WeightedPredictionsAggregator aggregator = new WeightedPredictionsAggregator(weights, meanLbVal); ModelsComposition currComposition = new ModelsComposition(models, aggregator); if(convCheck.isConverged(dataset, currComposition)) break; dataset.compute(part -> { - if(part.getCopyOfOriginalLabels() == null) - part.setCopyOfOriginalLabels(Arrays.copyOf(part.getLabels(), part.getLabels().length)); + if (part.getCopiedOriginalLabels() == null) + part.setCopiedOriginalLabels(Arrays.copyOf(part.getLabels(), part.getLabels().length)); for(int j = 0; j < part.getLabels().length; j++) { double mdlAnswer = currComposition.apply(VectorUtils.of(part.getFeatures()[j])); - double originalLbVal = externalLbToInternalMapping.apply(part.getCopyOfOriginalLabels()[j]); + double originalLbVal = externalLbToInternalMapping.apply(part.getCopiedOriginalLabels()[j]); part.getLabels()[j] = -loss.gradient(sampleSize, originalLbVal, mdlAnswer); } }); http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java index b8a16dc..335f751 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java @@ -28,13 +28,13 @@ import org.apache.ignite.ml.tree.TreeFilter; */ public class DecisionTreeData extends FeatureMatrixWithLabelsOnHeapData implements AutoCloseable { /** Copy of vector with original labels. Auxiliary for Gradient Boosting on Trees.*/ - private double[] copyOfOriginalLabels; + private double[] copiedOriginalLabels; /** Indexes cache. */ private final List<TreeDataIndex> indexesCache; /** Build index. */ - private final boolean buildIndex; + private final boolean buildIdx; /** * Constructs a new instance of decision tree data. @@ -45,7 +45,7 @@ public class DecisionTreeData extends FeatureMatrixWithLabelsOnHeapData implemen */ public DecisionTreeData(double[][] features, double[] labels, boolean buildIdx) { super(features, labels); - this.buildIndex = buildIdx; + this.buildIdx = buildIdx; indexesCache = new ArrayList<>(); if (buildIdx) @@ -81,7 +81,7 @@ public class DecisionTreeData extends FeatureMatrixWithLabelsOnHeapData implemen } } - return new DecisionTreeData(newFeatures, newLabels, buildIndex); + return new DecisionTreeData(newFeatures, newLabels, buildIdx); } /** @@ -129,13 +129,13 @@ public class DecisionTreeData extends FeatureMatrixWithLabelsOnHeapData implemen } /** */ - public double[] getCopyOfOriginalLabels() { - return copyOfOriginalLabels; + public double[] getCopiedOriginalLabels() { + return copiedOriginalLabels; } /** */ - public void setCopyOfOriginalLabels(double[] copyOfOriginalLabels) { - this.copyOfOriginalLabels = copyOfOriginalLabels; + public void setCopiedOriginalLabels(double[] copiedOriginalLabels) { + this.copiedOriginalLabels = copiedOriginalLabels; } /** {@inheritDoc} */ http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeDataBuilder.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeDataBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeDataBuilder.java index 6678218..4436b07 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeDataBuilder.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeDataBuilder.java @@ -43,7 +43,7 @@ public class DecisionTreeDataBuilder<K, V, C extends Serializable> private final IgniteBiFunction<K, V, Double> lbExtractor; /** Build index. */ - private final boolean buildIndex; + private final boolean buildIdx; /** * Constructs a new instance of decision tree data builder. @@ -56,7 +56,7 @@ public class DecisionTreeDataBuilder<K, V, C extends Serializable> IgniteBiFunction<K, V, Double> lbExtractor, boolean buildIdx) { this.featureExtractor = featureExtractor; this.lbExtractor = lbExtractor; - this.buildIndex = buildIdx; + this.buildIdx = buildIdx; } /** {@inheritDoc} */ @@ -75,6 +75,6 @@ public class DecisionTreeDataBuilder<K, V, C extends Serializable> ptr++; } - return new DecisionTreeData(features, labels, buildIndex); + return new DecisionTreeData(features, labels, buildIdx); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/TreeDataIndex.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/TreeDataIndex.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/TreeDataIndex.java index 88ce190..a86f78d 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/TreeDataIndex.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/TreeDataIndex.java @@ -26,7 +26,7 @@ import org.apache.ignite.ml.tree.TreeFilter; */ public class TreeDataIndex { /** Index containing IDs of rows as if they is sorted by feature values. */ - private final int[][] index; + private final int[][] idx; /** Original features table. */ private final double[][] features; @@ -48,9 +48,9 @@ public class TreeDataIndex { int cols = features.length == 0 ? 0 : features[0].length; double[][] featuresCp = new double[rows][cols]; - index = new int[rows][cols]; + idx = new int[rows][cols]; for (int row = 0; row < rows; row++) { - Arrays.fill(index[row], row); + Arrays.fill(idx[row], row); featuresCp[row] = Arrays.copyOf(features[row], cols); } @@ -61,12 +61,12 @@ public class TreeDataIndex { /** * Constructs an instance of TreeDataIndex * - * @param indexProj Index projection. + * @param idxProj Index projection. * @param features Features. * @param labels Labels. */ - private TreeDataIndex(int[][] indexProj, double[][] features, double[] labels) { - this.index = indexProj; + private TreeDataIndex(int[][] idxProj, double[][] features, double[] labels) { + this.idx = idxProj; this.features = features; this.labels = labels; } @@ -79,7 +79,7 @@ public class TreeDataIndex { * @return Label value. */ public double labelInSortedOrder(int k, int featureId) { - return labels[index[k][featureId]]; + return labels[idx[k][featureId]]; } /** @@ -90,7 +90,7 @@ public class TreeDataIndex { * @return Features vector. */ public double[] featuresInSortedOrder(int k, int featureId) { - return features[index[k][featureId]]; + return features[idx[k][featureId]]; } /** @@ -117,30 +117,30 @@ public class TreeDataIndex { projSize++; } - int[][] projection = new int[projSize][columnsCount()]; + int[][] prj = new int[projSize][columnsCount()]; for(int feature = 0; feature < columnsCount(); feature++) { int ptr = 0; for(int row = 0; row < rowsCount(); row++) { if(filter.test(featuresInSortedOrder(row, feature))) - projection[ptr++][feature] = index[row][feature]; + prj[ptr++][feature] = idx[row][feature]; } } - return new TreeDataIndex(projection, features, labels); + return new TreeDataIndex(prj, features, labels); } /** * @return count of rows in current index. */ public int rowsCount() { - return index.length; + return idx.length; } /** * @return count of columns in current index. */ public int columnsCount() { - return rowsCount() == 0 ? 0 : index[0].length ; + return rowsCount() == 0 ? 0 : idx[0].length; } /** @@ -168,9 +168,9 @@ public class TreeDataIndex { features[i][col] = features[j][col]; features[j][col] = tmpFeature; - int tmpLb = index[i][col]; - index[i][col] = index[j][col]; - index[j][col] = tmpLb; + int tmpLb = idx[i][col]; + idx[i][col] = idx[j][col]; + idx[j][col] = tmpLb; i++; j--; http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/ImpurityMeasureCalculator.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/ImpurityMeasureCalculator.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/ImpurityMeasureCalculator.java index 0c67535..b97e297 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/ImpurityMeasureCalculator.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/ImpurityMeasureCalculator.java @@ -32,15 +32,15 @@ import org.apache.ignite.ml.tree.impurity.util.StepFunction; */ public abstract class ImpurityMeasureCalculator<T extends ImpurityMeasure<T>> implements Serializable { /** Use index structure instead of using sorting while learning. */ - protected final boolean useIndex; + protected final boolean useIdx; /** * Constructs an instance of ImpurityMeasureCalculator. * - * @param useIndex Use index. + * @param useIdx Use index. */ - public ImpurityMeasureCalculator(boolean useIndex) { - this.useIndex = useIndex; + public ImpurityMeasureCalculator(boolean useIdx) { + this.useIdx = useIdx; } /** @@ -61,7 +61,7 @@ public abstract class ImpurityMeasureCalculator<T extends ImpurityMeasure<T>> im * @return Columns count in current dataset. */ protected int columnsCount(DecisionTreeData data, TreeDataIndex idx) { - return useIndex ? idx.columnsCount() : data.getFeatures()[0].length; + return useIdx ? idx.columnsCount() : data.getFeatures()[0].length; } /** @@ -72,7 +72,7 @@ public abstract class ImpurityMeasureCalculator<T extends ImpurityMeasure<T>> im * @return rows count in current dataset */ protected int rowsCount(DecisionTreeData data, TreeDataIndex idx) { - return useIndex ? idx.rowsCount() : data.getFeatures().length; + return useIdx ? idx.rowsCount() : data.getFeatures().length; } /** @@ -85,7 +85,7 @@ public abstract class ImpurityMeasureCalculator<T extends ImpurityMeasure<T>> im * @return label value in according to kth order statistic */ protected double getLabelValue(DecisionTreeData data, TreeDataIndex idx, int featureId, int k) { - return useIndex ? idx.labelInSortedOrder(k, featureId) : data.getLabels()[k]; + return useIdx ? idx.labelInSortedOrder(k, featureId) : data.getLabels()[k]; } /** @@ -98,10 +98,10 @@ public abstract class ImpurityMeasureCalculator<T extends ImpurityMeasure<T>> im * @return feature value in according to kth order statistic. */ protected double getFeatureValue(DecisionTreeData data, TreeDataIndex idx, int featureId, int k) { - return useIndex ? idx.featureInSortedOrder(k, featureId) : data.getFeatures()[k][featureId]; + return useIdx ? idx.featureInSortedOrder(k, featureId) : data.getFeatures()[k][featureId]; } protected Vector getFeatureValues(DecisionTreeData data, TreeDataIndex idx, int featureId, int k) { - return VectorUtils.of(useIndex ? idx.featuresInSortedOrder(k, featureId) : data.getFeatures()[k]); + return VectorUtils.of(useIdx ? idx.featuresInSortedOrder(k, featureId) : data.getFeatures()[k]); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculator.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculator.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculator.java index 38b3097..6a1eb0c 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculator.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculator.java @@ -39,22 +39,22 @@ public class GiniImpurityMeasureCalculator extends ImpurityMeasureCalculator<Gin * Constructs a new instance of Gini impurity measure calculator. * * @param lbEncoder Label encoder which defines integer value for every label class. - * @param useIndex Use index while calculate. + * @param useIdx Use index while calculate. */ - public GiniImpurityMeasureCalculator(Map<Double, Integer> lbEncoder, boolean useIndex) { - super(useIndex); + public GiniImpurityMeasureCalculator(Map<Double, Integer> lbEncoder, boolean useIdx) { + super(useIdx); this.lbEncoder = lbEncoder; } /** {@inheritDoc} */ @SuppressWarnings("unchecked") @Override public StepFunction<GiniImpurityMeasure>[] calculate(DecisionTreeData data, TreeFilter filter, int depth) { - TreeDataIndex index = null; + TreeDataIndex idx = null; boolean canCalculate = false; - if (useIndex) { - index = data.createIndexByFilter(depth, filter); - canCalculate = index.rowsCount() > 0; + if (useIdx) { + idx = data.createIndexByFilter(depth, filter); + canCalculate = idx.rowsCount() > 0; } else { data = data.filter(filter); @@ -62,47 +62,47 @@ public class GiniImpurityMeasureCalculator extends ImpurityMeasureCalculator<Gin } if (canCalculate) { - int rowsCnt = rowsCount(data, index); - int colsCnt = columnsCount(data, index); + int rowsCnt = rowsCount(data, idx); + int colsCnt = columnsCount(data, idx); StepFunction<GiniImpurityMeasure>[] res = new StepFunction[colsCnt]; long right[] = new long[lbEncoder.size()]; for (int i = 0; i < rowsCnt; i++) { - double lb = getLabelValue(data, index, 0, i); + double lb = getLabelValue(data, idx, 0, i); right[getLabelCode(lb)]++; } for (int col = 0; col < res.length; col++) { - if(!useIndex) + if (!useIdx) data.sort(col); double[] x = new double[rowsCnt + 1]; GiniImpurityMeasure[] y = new GiniImpurityMeasure[rowsCnt + 1]; long[] left = new long[lbEncoder.size()]; - long[] rightCopy = Arrays.copyOf(right, right.length); + long[] rightCp = Arrays.copyOf(right, right.length); int xPtr = 0, yPtr = 0; x[xPtr++] = Double.NEGATIVE_INFINITY; y[yPtr++] = new GiniImpurityMeasure( Arrays.copyOf(left, left.length), - Arrays.copyOf(rightCopy, rightCopy.length) + Arrays.copyOf(rightCp, rightCp.length) ); for (int i = 0; i < rowsCnt; i++) { - double lb = getLabelValue(data, index, col, i); + double lb = getLabelValue(data, idx, col, i); left[getLabelCode(lb)]++; - rightCopy[getLabelCode(lb)]--; + rightCp[getLabelCode(lb)]--; - double featureVal = getFeatureValue(data, index, col, i); - if (i < (rowsCnt - 1) && getFeatureValue(data, index, col, i + 1) == featureVal) + double featureVal = getFeatureValue(data, idx, col, i); + if (i < (rowsCnt - 1) && getFeatureValue(data, idx, col, i + 1) == featureVal) continue; x[xPtr++] = featureVal; y[yPtr++] = new GiniImpurityMeasure( Arrays.copyOf(left, left.length), - Arrays.copyOf(rightCopy, rightCopy.length) + Arrays.copyOf(rightCp, rightCp.length) ); } http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculator.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculator.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculator.java index 1788737..3629768 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculator.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculator.java @@ -33,20 +33,20 @@ public class MSEImpurityMeasureCalculator extends ImpurityMeasureCalculator<MSEI /** * Constructs an instance of MSEImpurityMeasureCalculator. * - * @param useIndex Use index while calculate. + * @param useIdx Use index while calculate. */ - public MSEImpurityMeasureCalculator(boolean useIndex) { - super(useIndex); + public MSEImpurityMeasureCalculator(boolean useIdx) { + super(useIdx); } /** {@inheritDoc} */ @Override public StepFunction<MSEImpurityMeasure>[] calculate(DecisionTreeData data, TreeFilter filter, int depth) { - TreeDataIndex index = null; - boolean canCalculate = false; + TreeDataIndex idx = null; + boolean canCalculate; - if (useIndex) { - index = data.createIndexByFilter(depth, filter); - canCalculate = index.rowsCount() > 0; + if (useIdx) { + idx = data.createIndexByFilter(depth, filter); + canCalculate = idx.rowsCount() > 0; } else { data = data.filter(filter); @@ -54,8 +54,8 @@ public class MSEImpurityMeasureCalculator extends ImpurityMeasureCalculator<MSEI } if (canCalculate) { - int rowsCnt = rowsCount(data, index); - int colsCnt = columnsCount(data, index); + int rowsCnt = rowsCount(data, idx); + int colsCnt = columnsCount(data, idx); @SuppressWarnings("unchecked") StepFunction<MSEImpurityMeasure>[] res = new StepFunction[colsCnt]; @@ -63,14 +63,14 @@ public class MSEImpurityMeasureCalculator extends ImpurityMeasureCalculator<MSEI double rightYOriginal = 0; double rightY2Original = 0; for (int i = 0; i < rowsCnt; i++) { - double lbVal = getLabelValue(data, index, 0, i); + double lbVal = getLabelValue(data, idx, 0, i); rightYOriginal += lbVal; rightY2Original += Math.pow(lbVal, 2); } for (int col = 0; col < res.length; col++) { - if (!useIndex) + if (!useIdx) data.sort(col); double[] x = new double[rowsCnt + 1]; @@ -86,7 +86,7 @@ public class MSEImpurityMeasureCalculator extends ImpurityMeasureCalculator<MSEI int leftSize = 0; for (int i = 0; i <= rowsCnt; i++) { if (leftSize > 0) { - double lblVal = getLabelValue(data, index, col, i - 1); + double lblVal = getLabelValue(data, idx, col, i - 1); leftY += lblVal; leftY2 += Math.pow(lblVal, 2); @@ -96,7 +96,7 @@ public class MSEImpurityMeasureCalculator extends ImpurityMeasureCalculator<MSEI } if (leftSize < rowsCnt) - x[leftSize + 1] = getFeatureValue(data, index, col, i); + x[leftSize + 1] = getFeatureValue(data, idx, col, i); y[leftSize] = new MSEImpurityMeasure( leftY, leftY2, leftSize, rightY, rightY2, rowsCnt - leftSize http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java index c617d8d..4a83eb2 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java @@ -73,7 +73,7 @@ public abstract class RandomForestTrainer<L, S extends ImpurityComputer<Bootstra private static final double BUCKET_SIZE_FACTOR = (1 / 10.0); /** Count of trees. */ - private int cntOfTrees = 1; + private int amountOfTrees = 1; /** Subsample size. */ private double subSampleSize = 1.0; @@ -115,7 +115,7 @@ public abstract class RandomForestTrainer<L, S extends ImpurityComputer<Bootstra List<TreeRoot> models = null; try (Dataset<EmptyContext, BootstrappedDatasetPartition> dataset = datasetBuilder.build( new EmptyContextBuilder<>(), - new BootstrappedDatasetBuilder<>(featureExtractor, lbExtractor, cntOfTrees, subSampleSize))) { + new BootstrappedDatasetBuilder<>(featureExtractor, lbExtractor, amountOfTrees, subSampleSize))) { if(!init(dataset)) return buildComposition(Collections.emptyList()); @@ -138,8 +138,8 @@ public abstract class RandomForestTrainer<L, S extends ImpurityComputer<Bootstra * @param cntOfTrees Count of trees. * @return an instance of current object with valid type in according to inheritance. */ - public T withCountOfTrees(int cntOfTrees) { - this.cntOfTrees = cntOfTrees; + public T withAmountOfTrees(int amountOfTrees) { + this.amountOfTrees = amountOfTrees; return instance(); } @@ -348,7 +348,7 @@ public abstract class RandomForestTrainer<L, S extends ImpurityComputer<Bootstra */ private Queue<TreeNode> createRootsQueue() { Queue<TreeNode> roots = new LinkedList<>(); - for (int i = 0; i < cntOfTrees; i++) + for (int i = 0; i < amountOfTrees; i++) roots.add(new TreeNode(1, i)); return roots; } http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/NodeSplit.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/NodeSplit.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/NodeSplit.java index 52d0b74..3ccb568 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/NodeSplit.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/NodeSplit.java @@ -27,7 +27,7 @@ public class NodeSplit { private final int featureId; /** Feature split value. */ - private final double value; + private final double val; /** Impurity at this split point. */ private final double impurity; @@ -36,12 +36,12 @@ public class NodeSplit { * Creates an instance of NodeSplit. * * @param featureId Feature id. - * @param value Feature split value. + * @param val Feature split value. * @param impurity Impurity value. */ - public NodeSplit(int featureId, double value, double impurity) { + public NodeSplit(int featureId, double val, double impurity) { this.featureId = featureId; - this.value = value; + this.val = val; this.impurity = impurity; } @@ -52,7 +52,7 @@ public class NodeSplit { * @return list of children. */ public List<TreeNode> split(TreeNode node) { - List<TreeNode> children = node.toConditional(featureId, value); + List<TreeNode> children = node.toConditional(featureId, val); node.setImpurity(impurity); return children; } @@ -73,7 +73,7 @@ public class NodeSplit { } /** */ - public double getValue() { - return value; + public double getVal() { + return val; } } http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/TreeNode.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/TreeNode.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/TreeNode.java index eb06143..528e31d 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/TreeNode.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/TreeNode.java @@ -51,7 +51,7 @@ public class TreeNode implements Model<Vector, Double>, Serializable { private int featureId; /** Value. */ - private double value; + private double val; /** Type. */ private Type type; @@ -76,7 +76,7 @@ public class TreeNode implements Model<Vector, Double>, Serializable { */ public TreeNode(long id, int treeId) { this.id = new NodeId(treeId, id); - this.value = -1; + this.val = -1; this.type = Type.UNKNOWN; this.impurity = Double.POSITIVE_INFINITY; this.depth = 1; @@ -87,9 +87,9 @@ public class TreeNode implements Model<Vector, Double>, Serializable { assert type != Type.UNKNOWN; if (type == Type.LEAF) - return value; + return val; else { - if (features.get(featureId) <= value) + if (features.get(featureId) <= val) return left.apply(features); else return right.apply(features); @@ -109,7 +109,7 @@ public class TreeNode implements Model<Vector, Double>, Serializable { case LEAF: return id; default: - if (features.get(featureId) <= value) + if (features.get(featureId) <= val) return left.predictNextNodeKey(features); else return right.predictNextNodeKey(features); @@ -120,12 +120,12 @@ public class TreeNode implements Model<Vector, Double>, Serializable { * Convert node to conditional node. * * @param featureId Feature id. - * @param value Value. + * @param val Value. */ - public List<TreeNode> toConditional(int featureId, double value) { + public List<TreeNode> toConditional(int featureId, double val) { assert type == Type.UNKNOWN; - toLeaf(value); + toLeaf(val); left = new TreeNode(2 * id.nodeId(), id.treeId()); right = new TreeNode(2 * id.nodeId() + 1, id.treeId()); this.type = Type.CONDITIONAL; @@ -138,12 +138,12 @@ public class TreeNode implements Model<Vector, Double>, Serializable { /** * Convert node to leaf. * - * @param value Value. + * @param val Value. */ - public void toLeaf(double value) { + public void toLeaf(double val) { assert type == Type.UNKNOWN; - this.value = value; + this.val = val; this.type = Type.LEAF; this.left = null; @@ -156,8 +156,8 @@ public class TreeNode implements Model<Vector, Double>, Serializable { } /** */ - public void setValue(double value) { - this.value = value; + public void setVal(double val) { + this.val = val; } /** */ http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/ImpurityHistogramsComputer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/ImpurityHistogramsComputer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/ImpurityHistogramsComputer.java index d1ed87f..8320461 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/ImpurityHistogramsComputer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/ImpurityHistogramsComputer.java @@ -183,9 +183,9 @@ public abstract class ImpurityHistogramsComputer<S extends ImpurityComputer<Boot */ private void addTo(Map<Integer, S> from, Map<Integer, S> to) { from.forEach((key, hist) -> { - if(!to.containsKey(key)) { + if (!to.containsKey(key)) to.put(key, hist); - } else { + else { S sumOfHists = to.get(key).plus(hist); to.put(key, sumOfHists); } http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/statistics/LeafValuesComputer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/statistics/LeafValuesComputer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/statistics/LeafValuesComputer.java index 056eece..cd343ef 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/statistics/LeafValuesComputer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/statistics/LeafValuesComputer.java @@ -65,7 +65,7 @@ public abstract class LeafValuesComputer<T> implements Serializable { T stat = stats.get(id); if(stat != null) { double leafVal = computeLeafValue(stat); - leaf.setValue(leafVal); + leaf.setVal(leafVal); } }); } http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/main/java/org/apache/ignite/ml/util/ModelTrace.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/util/ModelTrace.java b/modules/ml/src/main/java/org/apache/ignite/ml/util/ModelTrace.java index e6539d2..d34ab62 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/util/ModelTrace.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/util/ModelTrace.java @@ -68,10 +68,10 @@ public class ModelTrace { * Add field. * * @param name Name. - * @param value Value. + * @param val Value. */ - public ModelTrace addField(String name, String value) { - mdlFields.add(new IgniteBiTuple<>(name, value)); + public ModelTrace addField(String name, String val) { + mdlFields.add(new IgniteBiTuple<>(name, val)); return this; } http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansTrainerTest.java index 74ff8f1..205f0ff 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansTrainerTest.java @@ -22,6 +22,7 @@ import java.util.HashMap; import java.util.Map; import org.apache.ignite.ml.clustering.kmeans.KMeansModel; import org.apache.ignite.ml.clustering.kmeans.KMeansTrainer; +import org.apache.ignite.ml.common.TrainerTest; import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; import org.apache.ignite.ml.math.distances.EuclideanDistance; import org.apache.ignite.ml.math.primitives.vector.Vector; @@ -36,7 +37,7 @@ import static org.junit.Assert.assertTrue; /** * Tests for {@link KMeansTrainer}. */ -public class KMeansTrainerTest { +public class KMeansTrainerTest extends TrainerTest { /** Precision in test checks. */ private static final double PRECISION = 1e-2; @@ -59,7 +60,7 @@ public class KMeansTrainerTest { public void findOneClusters() { KMeansTrainer trainer = createAndCheckTrainer(); KMeansModel knnMdl = trainer.withAmountOfClusters(1).fit( - new LocalDatasetBuilder<>(data, 2), + new LocalDatasetBuilder<>(data, parts), (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), (k, v) -> v[2] ); @@ -77,19 +78,19 @@ public class KMeansTrainerTest { public void testUpdateMdl() { KMeansTrainer trainer = createAndCheckTrainer(); KMeansModel originalMdl = trainer.withAmountOfClusters(1).fit( - new LocalDatasetBuilder<>(data, 2), + new LocalDatasetBuilder<>(data, parts), (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), (k, v) -> v[2] ); KMeansModel updatedMdlOnSameDataset = trainer.update( originalMdl, - new LocalDatasetBuilder<>(data, 2), + new LocalDatasetBuilder<>(data, parts), (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), (k, v) -> v[2] ); KMeansModel updatedMdlOnEmptyDataset = trainer.update( originalMdl, - new LocalDatasetBuilder<>(new HashMap<Integer, double[]>(), 2), + new LocalDatasetBuilder<>(new HashMap<Integer, double[]>(), parts), (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), (k, v) -> v[2] ); http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/test/java/org/apache/ignite/ml/common/TrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/common/TrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/common/TrainerTest.java index 678ed44..5d3bb5f 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/common/TrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/common/TrainerTest.java @@ -28,7 +28,7 @@ import org.junit.runners.Parameterized; @RunWith(Parameterized.class) public class TrainerTest { /** Number of parts to be tested. */ - private static final int[] partsToBeTested = new int[]{1, 2, 3, 4, 5, 7, 100}; + private static final int[] partsToBeTested = new int[] {1, 2, 3, 4, 13}; /** Parameters. */ @Parameterized.Parameters(name = "Data divided on {0} partitions, training with batch size {1}") @@ -36,7 +36,7 @@ public class TrainerTest { List<Integer[]> res = new ArrayList<>(); for (int part : partsToBeTested) - res.add(new Integer[]{part}); + res.add(new Integer[] {part}); return res; } http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java index 4c3655b..4958b4b 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java @@ -21,6 +21,7 @@ import java.util.HashMap; import java.util.Map; import java.util.function.BiFunction; import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.common.TrainerTest; import org.apache.ignite.ml.composition.ModelsComposition; import org.apache.ignite.ml.composition.boosting.convergence.mean.MeanAbsValueConvergenceCheckerFactory; import org.apache.ignite.ml.composition.boosting.convergence.simple.ConvergenceCheckerStubFactory; @@ -38,7 +39,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; /** */ -public class GDBTrainerTest { +public class GDBTrainerTest extends TrainerTest { /** */ @Test public void testFitRegression() { http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/test/java/org/apache/ignite/ml/composition/predictionsaggregator/MeanValuePredictionsAggregatorTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/composition/predictionsaggregator/MeanValuePredictionsAggregatorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/composition/predictionsaggregator/MeanValuePredictionsAggregatorTest.java index e738716..0d46361 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/composition/predictionsaggregator/MeanValuePredictionsAggregatorTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/composition/predictionsaggregator/MeanValuePredictionsAggregatorTest.java @@ -21,7 +21,9 @@ import org.junit.Test; import static org.junit.Assert.assertEquals; +/** */ public class MeanValuePredictionsAggregatorTest { + /** Aggregator. */ private PredictionsAggregator aggregator = new MeanValuePredictionsAggregator(); /** */ http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/test/java/org/apache/ignite/ml/composition/predictionsaggregator/OnMajorityPredictionsAggregatorTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/composition/predictionsaggregator/OnMajorityPredictionsAggregatorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/composition/predictionsaggregator/OnMajorityPredictionsAggregatorTest.java index 8649b72..4d25a86 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/composition/predictionsaggregator/OnMajorityPredictionsAggregatorTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/composition/predictionsaggregator/OnMajorityPredictionsAggregatorTest.java @@ -22,6 +22,7 @@ import org.junit.Test; import static org.junit.Assert.assertEquals; public class OnMajorityPredictionsAggregatorTest { + /** Aggregator. */ private PredictionsAggregator aggregator = new OnMajorityPredictionsAggregator(); /** */ http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/test/java/org/apache/ignite/ml/dataset/feature/ObjectHistogramTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/dataset/feature/ObjectHistogramTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/feature/ObjectHistogramTest.java index 131b69b..9efb939 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/dataset/feature/ObjectHistogramTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/feature/ObjectHistogramTest.java @@ -74,10 +74,10 @@ public class ObjectHistogramTest { /** * @param hist History. - * @param expectedBuckets Expected buckets. - * @param expectedCounters Expected counters. + * @param expBuckets Expected buckets. + * @param expCounters Expected counters. */ - private void testBuckets(ObjectHistogram<Double> hist, int[] expectedBuckets, int[] expectedCounters) { + private void testBuckets(ObjectHistogram<Double> hist, int[] expBuckets, int[] expCounters) { int size = hist.buckets().size(); int[] buckets = new int[size]; int[] counters = new int[size]; @@ -87,8 +87,8 @@ public class ObjectHistogramTest { buckets[ptr++] = bucket; } - assertArrayEquals(expectedBuckets, buckets); - assertArrayEquals(expectedCounters, counters); + assertArrayEquals(expBuckets, buckets); + assertArrayEquals(expCounters, counters); } /** @@ -96,12 +96,12 @@ public class ObjectHistogramTest { */ @Test public void testAdd() { - double value = 100.; - hist1.addElement(value); - Optional<Double> counter = hist1.getValue(computeBucket(value)); + double val = 100.0; + hist1.addElement(val); + Optional<Double> cntr = hist1.getValue(computeBucket(val)); - assertTrue(counter.isPresent()); - assertEquals(1, counter.get().intValue()); + assertTrue(cntr.isPresent()); + assertEquals(1, cntr.get().intValue()); } /** @@ -109,8 +109,8 @@ public class ObjectHistogramTest { */ @Test public void testAddHist() { - ObjectHistogram<Double> result = hist1.plus(hist2); - testBuckets(result, new int[] {0, 1, 2, 3, 4, 5, 6}, new int[] {10, 8, 2, 1, 1, 2, 1}); + ObjectHistogram<Double> res = hist1.plus(hist2); + testBuckets(res, new int[] {0, 1, 2, 3, 4, 5, 6}, new int[] {10, 8, 2, 1, 1, 2, 1}); } /** @@ -133,18 +133,19 @@ public class ObjectHistogramTest { assertArrayEquals(new double[] {4., 7., 9., 10., 11., 12.}, sums, 0.01); } + /** */ @Test public void testOfSum() { IgniteFunction<Double, Integer> bucketMap = x -> (int) (Math.ceil(x * 100) % 100); - IgniteFunction<Double, Double> counterMap = x -> Math.pow(x, 2); + IgniteFunction<Double, Double> cntrMap = x -> Math.pow(x, 2); - ObjectHistogram<Double> forAllHistogram = new ObjectHistogram<>(bucketMap, counterMap); + ObjectHistogram<Double> forAllHistogram = new ObjectHistogram<>(bucketMap, cntrMap); Random rnd = new Random(); List<ObjectHistogram<Double>> partitions = new ArrayList<>(); int cntOfPartitions = rnd.nextInt(100); int sizeOfDataset = rnd.nextInt(10000); for(int i = 0; i < cntOfPartitions; i++) - partitions.add(new ObjectHistogram<>(bucketMap, counterMap)); + partitions.add(new ObjectHistogram<>(bucketMap, cntrMap)); for(int i = 0; i < sizeOfDataset; i++) { double objVal = rnd.nextDouble(); @@ -152,7 +153,7 @@ public class ObjectHistogramTest { partitions.get(rnd.nextInt(partitions.size())).addElement(objVal); } - Optional<ObjectHistogram<Double>> leftSum = partitions.stream().reduce((x,y) -> x.plus(y)); + Optional<ObjectHistogram<Double>> leftSum = partitions.stream().reduce(ObjectHistogram::plus); Optional<ObjectHistogram<Double>> rightSum = partitions.stream().reduce((x,y) -> y.plus(x)); assertTrue(leftSum.isPresent()); assertTrue(rightSum.isPresent()); @@ -162,9 +163,9 @@ public class ObjectHistogramTest { } /** - * @param value Value. + * @param val Value. */ - private int computeBucket(Double value) { - return (int)Math.rint(value); + private int computeBucket(Double val) { + return (int)Math.rint(val); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java index 7e5a079..73192f0 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java @@ -41,7 +41,7 @@ public class LearningEnvironmentTest { RandomForestRegressionTrainer trainer = new RandomForestRegressionTrainer( IntStream.range(0, 0).mapToObj( x -> new FeatureMeta("", 0, false)).collect(Collectors.toList()) - ).withCountOfTrees(101) + ).withAmountOfTrees(101) .withFeaturesCountSelectionStrgy(FeaturesCountSelectionStrategies.ONE_THIRD) .withMaxDepth(4) .withMinImpurityDelta(0.) http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java index 199644b..9c75824 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java @@ -26,7 +26,6 @@ import org.apache.ignite.ml.knn.ann.ANNClassificationModel; import org.apache.ignite.ml.knn.ann.ANNClassificationTrainer; import org.apache.ignite.ml.knn.classification.NNStrategy; import org.apache.ignite.ml.math.distances.EuclideanDistance; -import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.junit.Assert; import org.junit.Test; @@ -109,11 +108,16 @@ public class ANNClassificationTest extends TrainerTest { .withDistanceMeasure(new EuclideanDistance()) .withStrategy(NNStrategy.SIMPLE); - Vector v1 = VectorUtils.of(550, 550); - Vector v2 = VectorUtils.of(-550, -550); - TestUtils.assertEquals(originalMdl.apply(v1), updatedOnSameDataset.apply(v1), PRECISION); - TestUtils.assertEquals(originalMdl.apply(v2), updatedOnSameDataset.apply(v2), PRECISION); - TestUtils.assertEquals(originalMdl.apply(v1), updatedOnEmptyDataset.apply(v1), PRECISION); - TestUtils.assertEquals(originalMdl.apply(v2), updatedOnEmptyDataset.apply(v2), PRECISION); + Assert.assertNotNull(updatedOnSameDataset.getCandidates()); + + Assert.assertTrue(updatedOnSameDataset.toString().contains(NNStrategy.SIMPLE.name())); + Assert.assertTrue(updatedOnSameDataset.toString(true).contains(NNStrategy.SIMPLE.name())); + Assert.assertTrue(updatedOnSameDataset.toString(false).contains(NNStrategy.SIMPLE.name())); + + Assert.assertNotNull(updatedOnEmptyDataset.getCandidates()); + + Assert.assertTrue(updatedOnEmptyDataset.toString().contains(NNStrategy.SIMPLE.name())); + Assert.assertTrue(updatedOnEmptyDataset.toString(true).contains(NNStrategy.SIMPLE.name())); + Assert.assertTrue(updatedOnEmptyDataset.toString(false).contains(NNStrategy.SIMPLE.name())); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java index 52ff1ec..9ff0bc2 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java @@ -17,11 +17,10 @@ package org.apache.ignite.ml.knn; -import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; -import java.util.List; import java.util.Map; +import org.apache.ignite.ml.common.TrainerTest; import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; import org.apache.ignite.ml.knn.classification.NNStrategy; import org.apache.ignite.ml.knn.regression.KNNRegressionModel; @@ -32,34 +31,13 @@ import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector; import org.junit.Assert; import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; import static junit.framework.TestCase.assertEquals; /** * Tests for {@link KNNRegressionTrainer}. */ -@RunWith(Parameterized.class) -public class KNNRegressionTest { - /** Number of parts to be tested. */ - private static final int[] partsToBeTested = new int[] {1, 2, 3, 4, 5, 7, 100}; - - /** Number of partitions. */ - @Parameterized.Parameter - public int parts; - - /** Parameters. */ - @Parameterized.Parameters(name = "Data divided on {0} partitions, training with batch size {1}") - public static Iterable<Integer[]> data() { - List<Integer[]> res = new ArrayList<>(); - - for (int part : partsToBeTested) - res.add(new Integer[] {part}); - - return res; - } - +public class KNNRegressionTest extends TrainerTest { /** */ @Test public void testSimpleRegressionWithOneNeighbour() { http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/test/java/org/apache/ignite/ml/math/VectorUtilsTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/VectorUtilsTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/VectorUtilsTest.java index f8dc078..42d7efd 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/math/VectorUtilsTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/math/VectorUtilsTest.java @@ -23,6 +23,9 @@ import org.junit.Test; import static org.junit.Assert.assertEquals; +/** + * Tests for {@link VectorUtils } + */ public class VectorUtilsTest { /** */ @Test @@ -55,14 +58,12 @@ public class VectorUtilsTest { /** */ @Test(expected = NullPointerException.class) public void testFails1() { - double[] values = null; - VectorUtils.of(values); + VectorUtils.of((double[])null); } /** */ @Test(expected = NullPointerException.class) public void testFails2() { - Double[] values = null; - VectorUtils.of(values); + VectorUtils.of((Double[])null); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/test/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeapTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeapTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeapTest.java index 6af03df..b720695 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeapTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeapTest.java @@ -20,13 +20,12 @@ package org.apache.ignite.ml.math.isolve.lsqr; import java.util.Arrays; import java.util.HashMap; import java.util.Map; +import org.apache.ignite.ml.common.TrainerTest; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; import org.apache.ignite.ml.dataset.primitive.builder.data.SimpleLabeledDatasetDataBuilder; import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; @@ -35,26 +34,7 @@ import static org.junit.Assert.assertTrue; /** * Tests for {@link LSQROnHeap}. */ -@RunWith(Parameterized.class) -public class LSQROnHeapTest { - /** Parameters. */ - @Parameterized.Parameters(name = "Data divided on {0} partitions") - public static Iterable<Integer[]> data() { - return Arrays.asList( - new Integer[] {1}, - new Integer[] {2}, - new Integer[] {3}, - new Integer[] {5}, - new Integer[] {7}, - new Integer[] {100}, - new Integer[] {1000} - ); - } - - /** Number of partitions. */ - @Parameterized.Parameter - public int parts; - +public class LSQROnHeapTest extends TrainerTest { /** Tests solving simple linear system. */ @Test public void testSolveLinearSystem() { http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineMdlTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineMdlTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineMdlTest.java index d740577..e59d515 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineMdlTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineMdlTest.java @@ -38,6 +38,11 @@ public class PipelineMdlTest { verifyPredict(getMdl(new LogisticRegressionModel(weights, 1.0).withRawLabels(true))); } + /** + * Get the empty internal model. + * + * @param internalMdl Internal model. + */ private PipelineMdl<Integer, double[]> getMdl(LogisticRegressionModel internalMdl) { return new PipelineMdl<Integer, double[]>() .withFeatureExtractor(null)