Repository: ignite
Updated Branches:
  refs/heads/master c9368da76 -> 414f45e0a


IGNITE-9065: Gradient boosting optimization

this closes #4486


Project: http://git-wip-us.apache.org/repos/asf/ignite/repo
Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/414f45e0
Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/414f45e0
Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/414f45e0

Branch: refs/heads/master
Commit: 414f45e0af39e1f7acf8304eedb113ca305e9a21
Parents: c9368da
Author: Alexey Platonov <[email protected]>
Authored: Wed Aug 8 13:22:26 2018 +0300
Committer: Yury Babak <[email protected]>
Committed: Wed Aug 8 13:22:26 2018 +0300

----------------------------------------------------------------------
 .../GDBOnTreesRegressionTrainerExample.java     | 116 ++++++++++++
 .../GRBOnTreesRegressionTrainerExample.java     | 116 ------------
 .../boosting/GDBLearningStrategy.java           | 178 +++++++++++++++++++
 .../ml/composition/boosting/GDBTrainer.java     |  48 ++---
 .../org/apache/ignite/ml/tree/DecisionTree.java |   8 +-
 .../tree/DecisionTreeClassificationTrainer.java |   2 +-
 .../ml/tree/DecisionTreeRegressionTrainer.java  |   2 +-
 .../GDBBinaryClassifierOnTreesTrainer.java      |  11 +-
 .../boosting/GDBOnTreesLearningStrategy.java    |  97 ++++++++++
 .../boosting/GDBRegressionOnTreesTrainer.java   |  11 +-
 .../ignite/ml/tree/data/DecisionTreeData.java   |  11 ++
 .../impurity/ImpurityMeasureCalculator.java     |   6 +
 12 files changed, 457 insertions(+), 149 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/ignite/blob/414f45e0/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesRegressionTrainerExample.java
----------------------------------------------------------------------
diff --git 
a/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesRegressionTrainerExample.java
 
b/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesRegressionTrainerExample.java
new file mode 100644
index 0000000..fa7a0d4
--- /dev/null
+++ 
b/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesRegressionTrainerExample.java
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.tree.boosting;
+
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
+import org.apache.ignite.configuration.CacheConfiguration;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.trainers.DatasetTrainer;
+import org.apache.ignite.ml.tree.boosting.GDBRegressionOnTreesTrainer;
+import org.apache.ignite.thread.IgniteThread;
+import org.jetbrains.annotations.NotNull;
+
+/**
+ * Example represents a solution for the task of regression learning based on
+ * Gradient Boosting on trees implementation. It shows an initialization of 
{@link org.apache.ignite.ml.tree.boosting.GDBRegressionOnTreesTrainer},
+ * initialization of the Ignite Cache, the learning step, and a comparison of 
predicted and real values.
+ *
+ * In this example the dataset is created automatically by the parabolic 
function f(x) = x^2.
+ */
+public class GDBOnTreesRegressionTrainerExample {
+    /**
+     * Run example.
+     *
+     * @param args Command line arguments, none required.
+     */
+    public static void main(String... args) throws InterruptedException {
+        // Start ignite grid.
+        try (Ignite ignite = 
Ignition.start("examples/config/example-ignite.xml")) {
+            System.out.println(">>> Ignite grid started.");
+
+            IgniteThread igniteThread = new 
IgniteThread(ignite.configuration().getIgniteInstanceName(),
+                GDBOnTreesRegressionTrainerExample.class.getSimpleName(), () 
-> {
+
+                // Create cache with training data.
+                CacheConfiguration<Integer, double[]> trainingSetCfg = 
createCacheConfiguration();
+                IgniteCache<Integer, double[]> trainingSet = 
fillTrainingData(ignite, trainingSetCfg);
+
+                // Create regression trainer.
+                DatasetTrainer<Model<Vector, Double>, Double> trainer = new 
GDBRegressionOnTreesTrainer(1.0, 2000, 1, 0.);
+
+                // Train decision tree model.
+                Model<Vector, Double> mdl = trainer.fit(
+                    ignite,
+                    trainingSet,
+                    (k, v) -> VectorUtils.of(v[0]),
+                    (k, v) -> v[1]
+                );
+
+                System.out.println(">>> ---------------------------------");
+                System.out.println(">>> | Prediction\t| Valid answer \t|");
+                System.out.println(">>> ---------------------------------");
+
+                // Calculate score.
+                for (int x = -5; x < 5; x++) {
+                    double predicted = mdl.apply(VectorUtils.of(x));
+
+                    System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", 
predicted, Math.pow(x, 2));
+                }
+
+                System.out.println(">>> ---------------------------------");
+
+                System.out.println(">>> GDB Regression trainer example 
completed.");
+            });
+
+            igniteThread.start();
+            igniteThread.join();
+        }
+    }
+
+    /**
+     * Create cache configuration.
+     */
+    @NotNull private static CacheConfiguration<Integer, double[]> 
createCacheConfiguration() {
+        CacheConfiguration<Integer, double[]> trainingSetCfg = new 
CacheConfiguration<>();
+        trainingSetCfg.setName("TRAINING_SET");
+        trainingSetCfg.setAffinity(new RendezvousAffinityFunction(false, 10));
+        return trainingSetCfg;
+    }
+
+    /**
+     * Fill parabola training data.
+     *
+     * @param ignite Ignite.
+     * @param trainingSetCfg Training set config.
+     */
+    @NotNull private static IgniteCache<Integer, double[]> 
fillTrainingData(Ignite ignite,
+        CacheConfiguration<Integer, double[]> trainingSetCfg) {
+        IgniteCache<Integer, double[]> trainingSet = 
ignite.createCache(trainingSetCfg);
+        for(int i = -50; i <= 50; i++) {
+            double x = ((double)i) / 10.0;
+            double y = Math.pow(x, 2);
+            trainingSet.put(i, new double[] {x, y});
+        }
+        return trainingSet;
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/414f45e0/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GRBOnTreesRegressionTrainerExample.java
----------------------------------------------------------------------
diff --git 
a/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GRBOnTreesRegressionTrainerExample.java
 
b/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GRBOnTreesRegressionTrainerExample.java
deleted file mode 100644
index 71d405a..0000000
--- 
a/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GRBOnTreesRegressionTrainerExample.java
+++ /dev/null
@@ -1,116 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.examples.ml.tree.boosting;
-
-import org.apache.ignite.Ignite;
-import org.apache.ignite.IgniteCache;
-import org.apache.ignite.Ignition;
-import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
-import org.apache.ignite.configuration.CacheConfiguration;
-import org.apache.ignite.ml.Model;
-import org.apache.ignite.ml.math.primitives.vector.Vector;
-import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
-import org.apache.ignite.ml.trainers.DatasetTrainer;
-import org.apache.ignite.ml.tree.boosting.GDBRegressionOnTreesTrainer;
-import org.apache.ignite.thread.IgniteThread;
-import org.jetbrains.annotations.NotNull;
-
-/**
- * Example represents a solution for the task of regression learning based on
- * Gradient Boosting on trees implementation. It shows an initialization of 
{@link org.apache.ignite.ml.tree.boosting.GDBRegressionOnTreesTrainer},
- * initialization of Ignite Cache, learning step and comparing of predicted 
and real values.
- *
- * In this example dataset is creating automatically by parabolic function 
f(x) = x^2.
- */
-public class GRBOnTreesRegressionTrainerExample {
-    /**
-     * Run example.
-     *
-     * @param args Command line arguments, none required.
-     */
-    public static void main(String... args) throws InterruptedException {
-        // Start ignite grid.
-        try (Ignite ignite = 
Ignition.start("examples/config/example-ignite.xml")) {
-            System.out.println(">>> Ignite grid started.");
-
-            IgniteThread igniteThread = new 
IgniteThread(ignite.configuration().getIgniteInstanceName(),
-                GRBOnTreesRegressionTrainerExample.class.getSimpleName(), () 
-> {
-
-                // Create cache with training data.
-                CacheConfiguration<Integer, double[]> trainingSetCfg = 
createCacheConfiguration();
-                IgniteCache<Integer, double[]> trainingSet = 
fillTrainingData(ignite, trainingSetCfg);
-
-                // Create regression trainer.
-                DatasetTrainer<Model<Vector, Double>, Double> trainer = new 
GDBRegressionOnTreesTrainer(1.0, 2000, 1, 0.);
-
-                // Train decision tree model.
-                Model<Vector, Double> mdl = trainer.fit(
-                    ignite,
-                    trainingSet,
-                    (k, v) -> VectorUtils.of(v[0]),
-                    (k, v) -> v[1]
-                );
-
-                System.out.println(">>> ---------------------------------");
-                System.out.println(">>> | Prediction\t| Valid answer \t|");
-                System.out.println(">>> ---------------------------------");
-
-                // Calculate score.
-                for (int x = -5; x < 5; x++) {
-                    double predicted = mdl.apply(VectorUtils.of(x));
-
-                    System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", 
predicted, Math.pow(x, 2));
-                }
-
-                System.out.println(">>> ---------------------------------");
-
-                System.out.println(">>> GDB Regression trainer example 
completed.");
-            });
-
-            igniteThread.start();
-            igniteThread.join();
-        }
-    }
-
-    /**
-     * Create cache configuration.
-     */
-    @NotNull private static CacheConfiguration<Integer, double[]> 
createCacheConfiguration() {
-        CacheConfiguration<Integer, double[]> trainingSetCfg = new 
CacheConfiguration<>();
-        trainingSetCfg.setName("TRAINING_SET");
-        trainingSetCfg.setAffinity(new RendezvousAffinityFunction(false, 10));
-        return trainingSetCfg;
-    }
-
-    /**
-     * Fill parabola training data.
-     *
-     * @param ignite Ignite.
-     * @param trainingSetCfg Training set config.
-     */
-    @NotNull private static IgniteCache<Integer, double[]> 
fillTrainingData(Ignite ignite,
-        CacheConfiguration<Integer, double[]> trainingSetCfg) {
-        IgniteCache<Integer, double[]> trainingSet = 
ignite.createCache(trainingSetCfg);
-        for(int i = -50; i <= 50; i++) {
-            double x = ((double)i) / 10.0;
-            double y = Math.pow(x, 2);
-            trainingSet.put(i, new double[] {x, y});
-        }
-        return trainingSet;
-    }
-}

http://git-wip-us.apache.org/repos/asf/ignite/blob/414f45e0/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBLearningStrategy.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBLearningStrategy.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBLearningStrategy.java
new file mode 100644
index 0000000..375748a
--- /dev/null
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBLearningStrategy.java
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition.boosting;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.composition.ModelsComposition;
+import 
org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.environment.LearningEnvironment;
+import org.apache.ignite.ml.environment.logging.MLLogger;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.functions.IgniteSupplier;
+import org.apache.ignite.ml.math.functions.IgniteTriFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.trainers.DatasetTrainer;
+
+/**
+ * Learning strategy for gradient boosting.
+ */
+public class GDBLearningStrategy {
+    /** Learning environment. */
+    protected LearningEnvironment environment;
+
+    /** Count of iterations. */
+    protected int cntOfIterations;
+
+    /** Gradient of loss function. */
+    protected IgniteTriFunction<Long, Double, Double, Double> lossGradient;
+
+    /** External label to internal mapping. */
+    protected IgniteFunction<Double, Double> externalLbToInternalMapping;
+
+    /** Base model trainer builder. */
+    protected IgniteSupplier<DatasetTrainer<? extends Model<Vector, Double>, 
Double>> baseMdlTrainerBuilder;
+
+    /** Mean label value. */
+    protected double meanLabelValue;
+
+    /** Sample size. */
+    protected long sampleSize;
+
+    /** Composition weights. */
+    protected double[] compositionWeights;
+
+    /**
+     * Implementation of gradient boosting iterations. At each iteration this 
algorithm
+     * builds a regression model based on the gradient of the loss function 
for the current model composition.
+     *
+     * @param datasetBuilder Dataset builder.
+     * @param featureExtractor Feature extractor.
+     * @param lbExtractor Label extractor.
+     * @return list of learned models.
+     */
+    public <K, V> List<Model<Vector, Double>> learnModels(DatasetBuilder<K, V> 
datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, 
V, Double> lbExtractor) {
+
+        List<Model<Vector, Double>> models = new ArrayList<>();
+        DatasetTrainer<? extends Model<Vector, Double>, Double> trainer = 
baseMdlTrainerBuilder.get();
+        for (int i = 0; i < cntOfIterations; i++) {
+            double[] weights = Arrays.copyOf(compositionWeights, i);
+
+            WeightedPredictionsAggregator aggregator = new 
WeightedPredictionsAggregator(weights, meanLabelValue);
+            Model<Vector, Double> currComposition = new 
ModelsComposition(models, aggregator);
+
+            IgniteBiFunction<K, V, Double> lbExtractorWrap = (k, v) -> {
+                Double realAnswer = 
externalLbToInternalMapping.apply(lbExtractor.apply(k, v));
+                Double mdlAnswer = 
currComposition.apply(featureExtractor.apply(k, v));
+                return -lossGradient.apply(sampleSize, realAnswer, mdlAnswer);
+            };
+
+            long startTs = System.currentTimeMillis();
+            models.add(trainer.fit(datasetBuilder, featureExtractor, 
lbExtractorWrap));
+            double learningTime = (double)(System.currentTimeMillis() - 
startTs) / 1000.0;
+            environment.logger(getClass()).log(MLLogger.VerboseLevel.LOW, "One 
model training time was %.2fs", learningTime);
+        }
+
+        return models;
+    }
+
+    /**
+     * Sets learning environment.
+     *
+     * @param environment Learning Environment.
+     */
+    public GDBLearningStrategy withEnvironment(LearningEnvironment 
environment) {
+        this.environment = environment;
+        return this;
+    }
+
+    /**
+     * Sets count of iterations.
+     *
+     * @param cntOfIterations Count of iterations.
+     */
+    public GDBLearningStrategy withCntOfIterations(int cntOfIterations) {
+        this.cntOfIterations = cntOfIterations;
+        return this;
+    }
+
+    /**
+     * Sets gradient of loss function.
+     *
+     * @param lossGradient Loss gradient.
+     */
+    public GDBLearningStrategy withLossGradient(IgniteTriFunction<Long, 
Double, Double, Double> lossGradient) {
+        this.lossGradient = lossGradient;
+        return this;
+    }
+
+    /**
+     * Sets external to internal label representation mapping.
+     *
+     * @param externalLbToInternal External to internal label mapping.
+     */
+    public GDBLearningStrategy 
withExternalLabelToInternal(IgniteFunction<Double, Double> 
externalLbToInternal) {
+        this.externalLbToInternalMapping = externalLbToInternal;
+        return this;
+    }
+
+    /**
+     * Sets base model builder.
+     *
+     * @param buildBaseMdlTrainer Build base model trainer.
+     */
+    public GDBLearningStrategy 
withBaseModelTrainerBuilder(IgniteSupplier<DatasetTrainer<? extends 
Model<Vector, Double>, Double>> buildBaseMdlTrainer) {
+        this.baseMdlTrainerBuilder = buildBaseMdlTrainer;
+        return this;
+    }
+
+    /**
+     * Sets mean label value.
+     *
+     * @param meanLabelValue Mean label value.
+     */
+    public GDBLearningStrategy withMeanLabelValue(double meanLabelValue) {
+        this.meanLabelValue = meanLabelValue;
+        return this;
+    }
+
+    /**
+     * Sets sample size.
+     *
+     * @param sampleSize Sample size.
+     */
+    public GDBLearningStrategy withSampleSize(long sampleSize) {
+        this.sampleSize = sampleSize;
+        return this;
+    }
+
+    /**
+     * Sets composition weights vector.
+     *
+     * @param compositionWeights Composition weights.
+     */
+    public GDBLearningStrategy withCompositionWeights(double[] 
compositionWeights) {
+        this.compositionWeights = compositionWeights;
+        return this;
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/414f45e0/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java
index 8663d3d..5a0f52a 100644
--- 
a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java
@@ -17,7 +17,6 @@
 
 package org.apache.ignite.ml.composition.boosting;
 
-import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
 import org.apache.ignite.lang.IgniteBiTuple;
@@ -53,16 +52,18 @@ import org.jetbrains.annotations.NotNull;
  *
  * But in practice Decision Trees is most used regressors (see: {@link 
DecisionTreeRegressionTrainer}).
  */
-abstract class GDBTrainer extends DatasetTrainer<Model<Vector, Double>, 
Double> {
+public abstract class GDBTrainer extends DatasetTrainer<Model<Vector, Double>, 
Double> {
     /** Gradient step. */
     private final double gradientStep;
+
     /** Count of iterations. */
     private final int cntOfIterations;
+
     /**
      * Gradient of loss function. First argument is sample size, second 
argument is valid answer, third argument is
      * current model prediction.
      */
-    private final IgniteTriFunction<Long, Double, Double, Double> lossGradient;
+    protected final IgniteTriFunction<Long, Double, Double, Double> 
lossGradient;
 
     /**
      * Constructs GDBTrainer instance.
@@ -91,28 +92,23 @@ abstract class GDBTrainer extends 
DatasetTrainer<Model<Vector, Double>, Double>
         Double mean = initAndSampleSize.get1();
         Long sampleSize = initAndSampleSize.get2();
 
-        List<Model<Vector, Double>> models = new ArrayList<>();
         double[] compositionWeights = new double[cntOfIterations];
         Arrays.fill(compositionWeights, gradientStep);
         WeightedPredictionsAggregator resAggregator = new 
WeightedPredictionsAggregator(compositionWeights, mean);
 
         long learningStartTs = System.currentTimeMillis();
-        for (int i = 0; i < cntOfIterations; i++) {
-            double[] weights = Arrays.copyOf(compositionWeights, i);
-            WeightedPredictionsAggregator aggregator = new 
WeightedPredictionsAggregator(weights, mean);
-            Model<Vector, Double> currComposition = new 
ModelsComposition(models, aggregator);
-
-            IgniteBiFunction<K, V, Double> lbExtractorWrap = (k, v) -> {
-                Double realAnswer = 
externalLabelToInternal(lbExtractor.apply(k, v));
-                Double mdlAnswer = 
currComposition.apply(featureExtractor.apply(k, v));
-                return -lossGradient.apply(sampleSize, realAnswer, mdlAnswer);
-            };
-
-            long startTs = System.currentTimeMillis();
-            models.add(buildBaseModelTrainer().fit(datasetBuilder, 
featureExtractor, lbExtractorWrap));
-            double learningTime = (double)(System.currentTimeMillis() - 
startTs) / 1000.0;
-            environment.logger(getClass()).log(MLLogger.VerboseLevel.LOW, "One 
model training time was %.2fs", learningTime);
-        }
+
+        List<Model<Vector, Double>> models = getLearningStrategy()
+            .withBaseModelTrainerBuilder(this::buildBaseModelTrainer)
+            .withExternalLabelToInternal(this::externalLabelToInternal)
+            .withCntOfIterations(cntOfIterations)
+            .withCompositionWeights(compositionWeights)
+            .withEnvironment(environment)
+            .withLossGradient(lossGradient)
+            .withSampleSize(sampleSize)
+            .withMeanLabelValue(mean)
+            .learnModels(datasetBuilder, featureExtractor, lbExtractor);
+
         double learningTime = (double)(System.currentTimeMillis() - 
learningStartTs) / 1000.0;
         environment.logger(getClass()).log(MLLogger.VerboseLevel.LOW, "The 
training time was %.2fs", learningTime);
 
@@ -136,7 +132,8 @@ abstract class GDBTrainer extends 
DatasetTrainer<Model<Vector, Double>, Double>
     /**
      * Returns regressor model trainer for one step of GDB.
      */
-    @NotNull protected abstract DatasetTrainer<? extends Model<Vector, 
Double>, Double> buildBaseModelTrainer();
+    @NotNull
+    protected abstract DatasetTrainer<? extends Model<Vector, Double>, Double> 
buildBaseModelTrainer();
 
     /**
      * Maps external representation of label to internal.
@@ -191,4 +188,13 @@ abstract class GDBTrainer extends 
DatasetTrainer<Model<Vector, Double>, Double>
             throw new RuntimeException(e);
         }
     }
+
+    /**
+     * Returns learning strategy.
+     *
+     * @return learning strategy.
+     */
+    protected GDBLearningStrategy getLearningStrategy() {
+        return new GDBLearningStrategy();
+    }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/414f45e0/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java 
b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java
index 270f14a..de8994a 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java
@@ -79,20 +79,24 @@ public abstract class DecisionTree<T extends 
ImpurityMeasure<T>> extends Dataset
             new EmptyContextBuilder<>(),
             new DecisionTreeDataBuilder<>(featureExtractor, lbExtractor, 
useIndex)
         )) {
-            return split(dataset, e -> true, 0, 
getImpurityMeasureCalculator(dataset));
+            return fit(dataset);
         }
         catch (Exception e) {
             throw new RuntimeException(e);
         }
     }
 
+    public <K,V> DecisionTreeNode fit(Dataset<EmptyContext, DecisionTreeData> 
dataset) {
+        return split(dataset, e -> true, 0, 
getImpurityMeasureCalculator(dataset));
+    }
+
     /**
      * Returns impurity measure calculator.
      *
      * @param dataset Dataset.
      * @return Impurity measure calculator.
      */
-    abstract ImpurityMeasureCalculator<T> 
getImpurityMeasureCalculator(Dataset<EmptyContext, DecisionTreeData> dataset);
+    protected abstract ImpurityMeasureCalculator<T> 
getImpurityMeasureCalculator(Dataset<EmptyContext, DecisionTreeData> dataset);
 
     /**
      * Splits the node specified by the given dataset and predicate and 
returns decision tree node.

http://git-wip-us.apache.org/repos/asf/ignite/blob/414f45e0/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java
index f371334..f8fc769 100644
--- 
a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java
@@ -96,7 +96,7 @@ public class DecisionTreeClassificationTrainer extends 
DecisionTree<GiniImpurity
     }
 
     /** {@inheritDoc} */
-    @Override ImpurityMeasureCalculator<GiniImpurityMeasure> 
getImpurityMeasureCalculator(
+    @Override protected ImpurityMeasureCalculator<GiniImpurityMeasure> 
getImpurityMeasureCalculator(
         Dataset<EmptyContext, DecisionTreeData> dataset) {
         Set<Double> labels = dataset.compute(part -> {
 

http://git-wip-us.apache.org/repos/asf/ignite/blob/414f45e0/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java
index 7446237..4c9aac9 100644
--- 
a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java
@@ -64,7 +64,7 @@ public class DecisionTreeRegressionTrainer extends 
DecisionTree<MSEImpurityMeasu
     }
 
     /** {@inheritDoc} */
-    @Override ImpurityMeasureCalculator<MSEImpurityMeasure> 
getImpurityMeasureCalculator(
+    @Override protected ImpurityMeasureCalculator<MSEImpurityMeasure> 
getImpurityMeasureCalculator(
         Dataset<EmptyContext, DecisionTreeData> dataset) {
 
         return new MSEImpurityMeasureCalculator(useIndex);

http://git-wip-us.apache.org/repos/asf/ignite/blob/414f45e0/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBBinaryClassifierOnTreesTrainer.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBBinaryClassifierOnTreesTrainer.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBBinaryClassifierOnTreesTrainer.java
index 631e848..4d87b47 100644
--- 
a/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBBinaryClassifierOnTreesTrainer.java
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBBinaryClassifierOnTreesTrainer.java
@@ -17,10 +17,8 @@
 
 package org.apache.ignite.ml.tree.boosting;
 
-import org.apache.ignite.ml.Model;
 import org.apache.ignite.ml.composition.boosting.GDBBinaryClassifierTrainer;
-import org.apache.ignite.ml.math.primitives.vector.Vector;
-import org.apache.ignite.ml.trainers.DatasetTrainer;
+import org.apache.ignite.ml.composition.boosting.GDBLearningStrategy;
 import org.apache.ignite.ml.tree.DecisionTreeRegressionTrainer;
 import org.jetbrains.annotations.NotNull;
 
@@ -54,7 +52,7 @@ public class GDBBinaryClassifierOnTreesTrainer extends 
GDBBinaryClassifierTraine
     }
 
     /** {@inheritDoc} */
-    @NotNull @Override protected DatasetTrainer<? extends Model<Vector, 
Double>, Double> buildBaseModelTrainer() {
+    @NotNull @Override protected DecisionTreeRegressionTrainer 
buildBaseModelTrainer() {
         return new DecisionTreeRegressionTrainer(maxDepth, 
minImpurityDecrease).withUseIndex(useIndex);
     }
 
@@ -68,4 +66,9 @@ public class GDBBinaryClassifierOnTreesTrainer extends 
GDBBinaryClassifierTraine
         this.useIndex = useIndex;
         return this;
     }
+
+    /** {@inheritDoc} */
+    @Override protected GDBLearningStrategy getLearningStrategy() {
+        return new GDBOnTreesLearningStrategy(useIndex);
+    }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/414f45e0/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBOnTreesLearningStrategy.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBOnTreesLearningStrategy.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBOnTreesLearningStrategy.java
new file mode 100644
index 0000000..8589a79
--- /dev/null
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBOnTreesLearningStrategy.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.tree.boosting;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.composition.ModelsComposition;
+import org.apache.ignite.ml.composition.boosting.GDBLearningStrategy;
+import 
org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator;
+import org.apache.ignite.ml.dataset.Dataset;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import 
org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
+import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.environment.logging.MLLogger;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.trainers.DatasetTrainer;
+import org.apache.ignite.ml.tree.DecisionTree;
+import org.apache.ignite.ml.tree.data.DecisionTreeData;
+import org.apache.ignite.ml.tree.data.DecisionTreeDataBuilder;
+
+/**
+ * Gradient boosting on trees specific learning strategy reusing learning 
dataset with index between
+ * several learning iterations.
+ */
+public class GDBOnTreesLearningStrategy  extends GDBLearningStrategy {
+    private boolean useIndex;
+
+    /**
+     * Create an instance of learning strategy.
+     *
+     * @param useIndex Use index.
+     */
+    public GDBOnTreesLearningStrategy(boolean useIndex) {
+        this.useIndex = useIndex;
+    }
+
+    /** {@inheritDoc} */
+    @Override public <K, V> List<Model<Vector, Double>> 
learnModels(DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, 
V, Double> lbExtractor) {
+
+        DatasetTrainer<? extends Model<Vector, Double>, Double> trainer = 
baseMdlTrainerBuilder.get();
+        assert trainer instanceof DecisionTree;
+        DecisionTree decisionTreeTrainer = (DecisionTree) trainer;
+
+        List<Model<Vector, Double>> models = new ArrayList<>();
+        try (Dataset<EmptyContext, DecisionTreeData> dataset = 
datasetBuilder.build(
+            new EmptyContextBuilder<>(),
+            new DecisionTreeDataBuilder<>(featureExtractor, lbExtractor, 
useIndex)
+        )) {
+            for (int i = 0; i < cntOfIterations; i++) {
+                double[] weights = Arrays.copyOf(compositionWeights, i);
+                WeightedPredictionsAggregator aggregator = new 
WeightedPredictionsAggregator(weights, meanLabelValue);
+                Model<Vector, Double> currComposition = new 
ModelsComposition(models, aggregator);
+
+                dataset.compute(part -> {
+                    if(part.getCopyOfOriginalLabels() == null)
+                        
part.setCopyOfOriginalLabels(Arrays.copyOf(part.getLabels(), 
part.getLabels().length));
+
+                    for(int j = 0; j < part.getLabels().length; j++) {
+                        double mdlAnswer = 
currComposition.apply(VectorUtils.of(part.getFeatures()[j]));
+                        double originalLbVal = 
externalLbToInternalMapping.apply(part.getCopyOfOriginalLabels()[j]);
+                        part.getLabels()[j] = -lossGradient.apply(sampleSize, 
originalLbVal, mdlAnswer);
+                    }
+                });
+
+                long startTs = System.currentTimeMillis();
+                models.add(decisionTreeTrainer.fit(dataset));
+                double learningTime = (double)(System.currentTimeMillis() - 
startTs) / 1000.0;
+                environment.logger(getClass()).log(MLLogger.VerboseLevel.LOW, 
"One model training time was %.2fs", learningTime);
+            }
+        }
+        catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+
+        return models;
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/414f45e0/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBRegressionOnTreesTrainer.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBRegressionOnTreesTrainer.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBRegressionOnTreesTrainer.java
index 450dae3..e2a183c 100644
--- 
a/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBRegressionOnTreesTrainer.java
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBRegressionOnTreesTrainer.java
@@ -17,10 +17,8 @@
 
 package org.apache.ignite.ml.tree.boosting;
 
-import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.composition.boosting.GDBLearningStrategy;
 import org.apache.ignite.ml.composition.boosting.GDBRegressionTrainer;
-import org.apache.ignite.ml.math.primitives.vector.Vector;
-import org.apache.ignite.ml.trainers.DatasetTrainer;
 import org.apache.ignite.ml.tree.DecisionTreeRegressionTrainer;
 import org.jetbrains.annotations.NotNull;
 
@@ -54,7 +52,7 @@ public class GDBRegressionOnTreesTrainer extends 
GDBRegressionTrainer {
     }
 
     /** {@inheritDoc} */
-    @NotNull @Override protected DatasetTrainer<? extends Model<Vector, 
Double>, Double> buildBaseModelTrainer() {
+    @NotNull @Override protected DecisionTreeRegressionTrainer 
buildBaseModelTrainer() {
         return new DecisionTreeRegressionTrainer(maxDepth, 
minImpurityDecrease).withUseIndex(useIndex);
     }
 
@@ -68,4 +66,9 @@ public class GDBRegressionOnTreesTrainer extends 
GDBRegressionTrainer {
         this.useIndex = useIndex;
         return this;
     }
+
    /**
     * {@inheritDoc}
     *
     * <p>Returns the tree-specific strategy that builds the (optionally indexed) learning dataset once
     * and reuses it across boosting iterations.
     */
    @Override protected GDBLearningStrategy getLearningStrategy() {
        return new GDBOnTreesLearningStrategy(useIndex);
    }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/414f45e0/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java 
b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java
index c017e5c..d5750ea 100644
--- 
a/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java
@@ -31,6 +31,9 @@ public class DecisionTreeData implements AutoCloseable {
     /** Vector with labels. */
     private final double[] labels;
 
+    /** Copy of vector with original labels. Auxiliary for gradient boosting on trees. */
+    private double[] copyOfOriginalLabels;
+
     /** Indexes cache. */
     private final List<TreeDataIndex> indexesCache;
 
@@ -137,6 +140,14 @@ public class DecisionTreeData implements AutoCloseable {
         return labels;
     }
 
    /**
     * @return Copy of the original labels, or {@code null} if it has not been initialized yet
     * (boosting initializes it lazily before overwriting {@code labels} with gradients).
     */
    public double[] getCopyOfOriginalLabels() {
        return copyOfOriginalLabels;
    }

    /**
     * Sets the copy of original labels, preserving them while {@code labels} is reused as a
     * gradient buffer by the GDB learning strategy.
     *
     * @param copyOfOriginalLabels Copy of original labels.
     */
    public void setCopyOfOriginalLabels(double[] copyOfOriginalLabels) {
        this.copyOfOriginalLabels = copyOfOriginalLabels;
    }
+
     /** {@inheritDoc} */
     @Override public void close() {
         // Do nothing, GC will clean up.

http://git-wip-us.apache.org/repos/asf/ignite/blob/414f45e0/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/ImpurityMeasureCalculator.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/ImpurityMeasureCalculator.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/ImpurityMeasureCalculator.java
index 709f68e..0c67535 100644
--- 
a/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/ImpurityMeasureCalculator.java
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/ImpurityMeasureCalculator.java
@@ -18,6 +18,8 @@
 package org.apache.ignite.ml.tree.impurity;
 
 import java.io.Serializable;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.apache.ignite.ml.tree.TreeFilter;
 import org.apache.ignite.ml.tree.data.DecisionTreeData;
 import org.apache.ignite.ml.tree.data.TreeDataIndex;
@@ -98,4 +100,8 @@ public abstract class ImpurityMeasureCalculator<T extends 
ImpurityMeasure<T>> im
     protected double getFeatureValue(DecisionTreeData data, TreeDataIndex idx, 
int featureId, int k) {
         return useIndex ? idx.featureInSortedOrder(k, featureId) : 
data.getFeatures()[k][featureId];
     }
+
    /**
     * Returns the feature vector of the k-th row: taken through the index when {@code useIndex}
     * is enabled, otherwise directly from the data.
     * NOTE(review): in the indexed branch {@code k} appears to be a position in the sort order of
     * {@code featureId} (mirroring {@code getFeatureValue} above), not a raw row index — confirm
     * callers account for this.
     *
     * @param data The partition data.
     * @param idx Index over the partition (consulted only when {@code useIndex} is enabled).
     * @param featureId Feature id defining the sort order when the index is used.
     * @param k Row position.
     * @return Vector of feature values for the row.
     */
    protected Vector getFeatureValues(DecisionTreeData data, TreeDataIndex idx, int featureId, int k) {
        return VectorUtils.of(useIndex ? idx.featuresInSortedOrder(k, featureId) : data.getFeatures()[k]);
    }
 }

Reply via email to