http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasure.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasure.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasure.java new file mode 100644 index 0000000..3fc8515 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasure.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.tree.impurity.mse; + +import org.apache.ignite.ml.tree.impurity.ImpurityMeasure; + +/** + * Mean squared error (variance) impurity measure which is calculated the following way: + * {@code \frac{1}{L}\sum_{i=0}^{n}(y_i - \mu)^2}. + */ +public class MSEImpurityMeasure implements ImpurityMeasure<MSEImpurityMeasure> { + /** */ + private static final long serialVersionUID = 4536394578628409689L; + + /** Sum of all elements in the left part. */ + private final double leftY; + + /** Sum of all squared elements in the left part. 
*/ + private final double leftY2; + + /** Number of elements in the left part. */ + private final long leftCnt; + + /** Sum of all elements in the right part. */ + private final double rightY; + + /** Sum of all squared elements in the right part. */ + private final double rightY2; + + /** Number of elements in the right part. */ + private final long rightCnt; + + /** + * Constructs a new instance of mean squared error (variance) impurity measure. + * + * @param leftY Sum of all elements in the left part. + * @param leftY2 Sum of all squared elements in the left part. + * @param leftCnt Number of elements in the left part. + * @param rightY Sum of all elements in the right part. + * @param rightY2 Sum of all squared elements in the right part. + * @param rightCnt Number of elements in the right part. + */ + public MSEImpurityMeasure(double leftY, double leftY2, long leftCnt, double rightY, double rightY2, long rightCnt) { + this.leftY = leftY; + this.leftY2 = leftY2; + this.leftCnt = leftCnt; + this.rightY = rightY; + this.rightY2 = rightY2; + this.rightCnt = rightCnt; + } + + /** {@inheritDoc} */ + @Override public double impurity() { + double impurity = 0; + + if (leftCnt > 0) + impurity += leftY2 - 2.0 * leftY / leftCnt * leftY + Math.pow(leftY / leftCnt, 2) * leftCnt; + + if (rightCnt > 0) + impurity += rightY2 - 2.0 * rightY / rightCnt * rightY + Math.pow(rightY / rightCnt, 2) * rightCnt; + + return impurity; + } + + /** {@inheritDoc} */ + @Override public MSEImpurityMeasure add(MSEImpurityMeasure b) { + return new MSEImpurityMeasure( + leftY + b.leftY, + leftY2 + b.leftY2, + leftCnt + b.leftCnt, + rightY + b.rightY, + rightY2 + b.rightY2, + rightCnt + b.rightCnt + ); + } + + /** {@inheritDoc} */ + @Override public MSEImpurityMeasure subtract(MSEImpurityMeasure b) { + return new MSEImpurityMeasure( + leftY - b.leftY, + leftY2 - b.leftY2, + leftCnt - b.leftCnt, + rightY - b.rightY, + rightY2 - b.rightY2, + rightCnt - b.rightCnt + ); + } + + /** */ + public 
double getLeftY() { + return leftY; + } + + /** */ + public double getLeftY2() { + return leftY2; + } + + /** */ + public long getLeftCnt() { + return leftCnt; + } + + /** */ + public double getRightY() { + return rightY; + } + + /** */ + public double getRightY2() { + return rightY2; + } + + /** */ + public long getRightCnt() { + return rightCnt; + } +}
http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculator.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculator.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculator.java new file mode 100644 index 0000000..cb5019c --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculator.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.tree.impurity.mse; + +import org.apache.ignite.ml.tree.data.DecisionTreeData; +import org.apache.ignite.ml.tree.impurity.ImpurityMeasureCalculator; +import org.apache.ignite.ml.tree.impurity.util.StepFunction; +
/**
 * Mean squared error (variance) impurity measure calculator.
 */
public class MSEImpurityMeasureCalculator implements ImpurityMeasureCalculator<MSEImpurityMeasure> {
    /** */
    private static final long serialVersionUID = 288747414953756824L;

    /**
     * {@inheritDoc}
     * <p>
     * For every feature column the data is sorted by that column. A step function is then built whose step
     * {@code k} starts at the value of row {@code k - 1} (the first step starts at negative infinity) and describes
     * the split "rows {@code 0..k-1} go left, the rest go right". Returns {@code null} if the data has no rows.
     */
    @SuppressWarnings("unchecked")
    @Override public StepFunction<MSEImpurityMeasure>[] calculate(DecisionTreeData data) {
        double[][] features = data.getFeatures();
        double[] labels = data.getLabels();

        if (features.length == 0)
            return null;

        StepFunction<MSEImpurityMeasure>[] res = new StepFunction[features[0].length];

        for (int col = 0; col < res.length; col++) {
            // NOTE(review): the arrays read above are assumed to be reordered in place by sort(); confirm
            // DecisionTreeData.sort semantics.
            data.sort(col);

            res[col] = calculateForColumn(features, labels, col);
        }

        return res;
    }

    /**
     * Builds the impurity step function for a single column of data already sorted by that column.
     * <p>
     * Uses running prefix sums for the left part and precomputed suffix sums for the right part, so the whole column
     * is processed in linear time instead of re-summing both parts for every split point.
     *
     * @param features Features matrix (rows sorted by {@code col}).
     * @param labels Labels aligned with {@code features}.
     * @param col Column index.
     * @return Step function with {@code features.length + 1} steps.
     */
    private static StepFunction<MSEImpurityMeasure> calculateForColumn(double[][] features, double[] labels,
        int col) {
        int size = features.length;

        double[] x = new double[size + 1];
        MSEImpurityMeasure[] y = new MSEImpurityMeasure[size + 1];

        // rightY[k] / rightY2[k] hold the sum / sum of squares of labels[k..size-1].
        double[] rightY = new double[size + 1];
        double[] rightY2 = new double[size + 1];

        for (int i = size - 1; i >= 0; i--) {
            rightY[i] = rightY[i + 1] + labels[i];
            rightY2[i] = rightY2[i + 1] + labels[i] * labels[i];
        }

        x[0] = Double.NEGATIVE_INFINITY;

        double leftY = 0;
        double leftY2 = 0;

        for (int leftSize = 0; leftSize <= size; leftSize++) {
            if (leftSize > 0) {
                double lb = labels[leftSize - 1];

                leftY += lb;
                leftY2 += lb * lb;
            }

            if (leftSize < size)
                x[leftSize + 1] = features[leftSize][col];

            y[leftSize] = new MSEImpurityMeasure(
                leftY, leftY2, leftSize, rightY[leftSize], rightY2[leftSize], size - leftSize
            );
        }

        return new StepFunction<>(x, y);
    }
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/mse/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/mse/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/mse/package-info.java new file mode 100644 index 0000000..23ec4e0 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/mse/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the
Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Contains mean squared error impurity measure and calculator. + */ +package org.apache.ignite.ml.tree.impurity.mse; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/package-info.java new file mode 100644 index 0000000..4155593 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Root package for decision tree impurity measures and calculators. + */ +package org.apache.ignite.ml.tree.impurity; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/util/SimpleStepFunctionCompressor.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/util/SimpleStepFunctionCompressor.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/util/SimpleStepFunctionCompressor.java new file mode 100644 index 0000000..2418571 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/util/SimpleStepFunctionCompressor.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.tree.impurity.util; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import org.apache.ignite.ml.tree.impurity.ImpurityMeasure; +
/**
 * Simple step function compressor.
 * <p>
 * Functions with fewer than {@code minSizeToBeCompressed} steps are returned unchanged. Otherwise the impurity of
 * every step is normalized to {@code [0, 1]}. The first step is always kept; a later step is kept only if its
 * normalized impurity changes, relative to the last kept step, by at least {@code minImpurityIncreaseForRecord}
 * (upwards) or {@code minImpurityDecreaseForRecord} (downwards).
 *
 * @param <T> Type of step function values.
 */
public class SimpleStepFunctionCompressor<T extends ImpurityMeasure<T>> implements StepFunctionCompressor<T> {
    /** */
    private static final long serialVersionUID = -3231787633598409157L;

    /** Min size of step function to be compressed. */
    private final int minSizeToBeCompressed;

    /** In case of compression min (normalized) impurity increase that will be recorded. */
    private final double minImpurityIncreaseForRecord;

    /** In case of compression min (normalized) impurity decrease that will be recorded. */
    private final double minImpurityDecreaseForRecord;

    /**
     * Constructs a new instance of simple step function compressor with default parameters
     * ({@code minSizeToBeCompressed = 10}, {@code minImpurityIncreaseForRecord = 0.1},
     * {@code minImpurityDecreaseForRecord = 0.05}).
     */
    public SimpleStepFunctionCompressor() {
        this(10, 0.1, 0.05);
    }

    /**
     * Constructs a new instance of simple step function compressor.
     *
     * @param minSizeToBeCompressed Min size of step function to be compressed.
     * @param minImpurityIncreaseForRecord In case of compression min impurity increase that will be recorded
     * (applied to impurities normalized to {@code [0, 1]}).
     * @param minImpurityDecreaseForRecord In case of compression min impurity decrease that will be recorded
     * (applied to impurities normalized to {@code [0, 1]}).
     */
    public SimpleStepFunctionCompressor(int minSizeToBeCompressed, double minImpurityIncreaseForRecord,
        double minImpurityDecreaseForRecord) {
        this.minSizeToBeCompressed = minSizeToBeCompressed;
        this.minImpurityIncreaseForRecord = minImpurityIncreaseForRecord;
        this.minImpurityDecreaseForRecord = minImpurityDecreaseForRecord;
    }

    /** {@inheritDoc} */
    @Override public StepFunction<T> compress(StepFunction<T> function) {
        double[] arguments = function.getX();
        T[] values = function.getY();

        if (arguments.length < minSizeToBeCompressed)
            return function;

        List<StepFunctionPoint<T>> points = new ArrayList<>(arguments.length);

        for (int i = 0; i < arguments.length; i++)
            points.add(new StepFunctionPoint<>(arguments[i], values[i]));

        points = compressPoints(points);

        double[] resX = new double[points.size()];
        // Arrays.copyOf is used only to obtain an array with the runtime component type of T[];
        // every element is overwritten below.
        T[] resY = Arrays.copyOf(values, points.size());

        for (int i = 0; i < points.size(); i++) {
            StepFunctionPoint<T> pnt = points.get(i);

            resX[i] = pnt.x;
            resY[i] = pnt.y;
        }

        return new StepFunction<>(resX, resY);
    }

    /**
     * Compresses list of step function points.
     *
     * @param points Step function points.
     * @return Compressed step function points (the first point is always kept).
     */
    private List<StepFunctionPoint<T>> compressPoints(List<StepFunctionPoint<T>> points) {
        // Evaluate every impurity once: it is needed both for the range and for the normalization.
        double[] impurities = new double[points.size()];

        // Start from the infinities: Double.MIN_VALUE is the smallest *positive* double, not the most negative one.
        double minImpurity = Double.POSITIVE_INFINITY;
        double maxImpurity = Double.NEGATIVE_INFINITY;

        for (int i = 0; i < impurities.length; i++) {
            double impurity = points.get(i).y.impurity();

            impurities[i] = impurity;

            if (impurity > maxImpurity)
                maxImpurity = impurity;

            if (impurity < minImpurity)
                minImpurity = impurity;
        }

        double range = maxImpurity - minImpurity;

        List<StepFunctionPoint<T>> res = new ArrayList<>();

        boolean first = true;
        double prev = 0;

        for (int i = 0; i < impurities.length; i++) {
            // A constant function has nothing to normalize against: map it to 0 instead of dividing by zero (NaN).
            double impurity = range > 0 ? (impurities[i] - minImpurity) / range : 0;

            if (first ||
                prev - impurity >= minImpurityDecreaseForRecord ||
                impurity - prev >= minImpurityIncreaseForRecord) {
                first = false;
                prev = impurity;
                res.add(points.get(i));
            }
        }

        return res;
    }

    /**
     * Util class that represents step function point. Static so that it does not capture the enclosing compressor.
     *
     * @param <P> Type of the step value.
     */
    private static final class StepFunctionPoint<P> {
        /** Argument of the step start. */
        private final double x;

        /** Value of the step. */
        private final P y;

        /**
         * Constructs a new instance of util class that represents step function point.
         *
         * @param x Argument of the step start.
         * @param y Value of the step.
         */
        StepFunctionPoint(double x, P y) {
            this.x = x;
            this.y = y;
        }
    }
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/util/StepFunction.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/util/StepFunction.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/util/StepFunction.java new file mode 100644 index 0000000..431503d --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/util/StepFunction.java @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.tree.impurity.util; + +import java.util.Arrays; +import org.apache.ignite.ml.tree.impurity.ImpurityMeasure; +
/**
 * Step function described by {@code x} and {@code y} points: {@code y[i]} is the value of the step that starts at
 * argument {@code x[i]}.
 *
 * @param <T> Type of function values.
 */
public class StepFunction<T extends ImpurityMeasure<T>> {
    /** Start argument of every step. Kept sorted in ascending order. */
    private final double[] x;

    /** Value of every step. */
    private final T[] y;

    /**
     * Constructs a new instance of step function.
     * <p>
     * The given arrays are not copied: they are sorted in place (by argument, keeping values paired with their
     * arguments) and then used as the internal state of this function.
     *
     * @param x Start argument of every step.
     * @param y Value of every step.
     */
    public StepFunction(double[] x, T[] y) {
        assert x.length == y.length : "Argument and value arrays have to be the same length";

        this.x = x;
        this.y = y;

        sort(x, y, 0, x.length - 1);
    }

    /**
     * Adds the given step function to this.
     * <p>
     * The result has a step at every distinct argument of either function. Its value there is the sum of the current
     * values of both functions, where a function contributes nothing before its first step.
     *
     * @param b Another step function.
     * @return Sum of this and the given function.
     */
    public StepFunction<T> add(StepFunction<T> b) {
        // First pass: count the distinct arguments of the merged sequence.
        int resSize = 0;
        double previousPnt = 0;

        int leftPtr = 0;
        int rightPtr = 0;

        while (leftPtr < x.length || rightPtr < b.x.length) {
            double pnt = takeLeft(b, leftPtr, rightPtr) ? x[leftPtr++] : b.x[rightPtr++];

            if (resSize == 0 || pnt != previousPnt) {
                previousPnt = pnt;
                resSize++;
            }
        }

        double[] resX = new double[resSize];
        // Arrays.copyOf is used only to obtain an array with the runtime component type of T[];
        // every slot is written before it is read.
        T[] resY = Arrays.copyOf(y, resSize);

        // Second pass: merge the steps, tracking how many result steps are filled so far.
        int filled = 0;

        leftPtr = 0;
        rightPtr = 0;

        while (leftPtr < x.length || rightPtr < b.x.length) {
            double pnt;
            T[] vals;
            int ptr;

            if (takeLeft(b, leftPtr, rightPtr)) {
                pnt = x[leftPtr];
                vals = y;
                ptr = leftPtr++;
            }
            else {
                pnt = b.x[rightPtr];
                vals = b.y;
                ptr = rightPtr++;
            }

            // A point equal to the last recorded argument updates that step instead of starting a new one.
            boolean override = filled > 0 && pnt == resX[filled - 1];
            int target = override ? filled - 1 : filled;

            T prevVal = filled > 0 ? resY[filled - 1] : null;
            T val = prevVal == null ? vals[ptr] : prevVal.add(vals[ptr]);

            // The source function moves from its previous step to the current one: replace the old contribution.
            if (ptr > 0)
                val = val.subtract(vals[ptr - 1]);

            resY[target] = val;
            resX[target] = pnt;

            filled = target + 1;
        }

        return new StepFunction<>(resX, resY);
    }

    /**
     * Decides which function supplies the next point of the merged argument sequence.
     *
     * @param b Another step function.
     * @param leftPtr Current position in this function.
     * @param rightPtr Current position in {@code b}.
     * @return {@code true} if the next point is taken from this function, {@code false} if it is taken from
     * {@code b} (ties go to {@code b}).
     */
    private boolean takeLeft(StepFunction<T> b, int leftPtr, int rightPtr) {
        return rightPtr >= b.x.length || (leftPtr < x.length && x[leftPtr] < b.x[rightPtr]);
    }

    /**
     * Sorts both arrays in place by {@code x} (quicksort), keeping {@code y[i]} paired with {@code x[i]}.
     *
     * @param x Arguments.
     * @param y Values.
     * @param from First index of the range to sort (inclusive).
     * @param to Last index of the range to sort (inclusive).
     */
    private void sort(double[] x, T[] y, int from, int to) {
        if (from < to) {
            double pivot = x[(from + to) >>> 1];

            int i = from, j = to;
            while (i <= j) {
                while (x[i] < pivot) i++;
                while (x[j] > pivot) j--;

                if (i <= j) {
                    double tmpX = x[i];
                    x[i] = x[j];
                    x[j] = tmpX;

                    T tmpY = y[i];
                    y[i] = y[j];
                    y[j] = tmpY;

                    i++;
                    j--;
                }
            }

            sort(x, y, from, j);
            sort(x, y, i, to);
        }
    }

    /** @return Start arguments of the steps (internal array, not a copy). */
    public double[] getX() {
        return x;
    }

    /** @return Values of the steps (internal array, not a copy). */
    public T[] getY() {
        return y;
    }
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/util/StepFunctionCompressor.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/util/StepFunctionCompressor.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/util/StepFunctionCompressor.java new file mode 100644 index 0000000..41baa29 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/util/StepFunctionCompressor.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.tree.impurity.util; + +import java.io.Serializable; +import java.util.Arrays; +import org.apache.ignite.ml.tree.impurity.ImpurityMeasure; +
/**
 * Base interface for step function compressors, which reduce the number of steps of a step function.
 *
 * @param <T> Type of step function value.
 */
public interface StepFunctionCompressor<T extends ImpurityMeasure<T>> extends Serializable {
    /**
     * Compresses the given step function.
     *
     * @param function Step function.
     * @return Compressed step function.
     */
    StepFunction<T> compress(StepFunction<T> function);

    /**
     * Compresses every step function of the given array. Results are written into a copy of the array; the input
     * array itself is not modified.
     *
     * @param functions Array of step functions, may be {@code null}.
     * @return Array of compressed step functions, or {@code null} if {@code functions} is {@code null}.
     */
    default StepFunction<T>[] compress(StepFunction<T>[] functions) {
        if (functions == null)
            return null;

        StepFunction<T>[] compressed = Arrays.copyOf(functions, functions.length);

        int idx = 0;
        while (idx < compressed.length) {
            compressed[idx] = compress(functions[idx]);
            idx++;
        }

        return compressed;
    }
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/util/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/util/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/util/package-info.java new file mode 100644 index 0000000..99df618 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/util/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Contains util classes used in decision tree impurity calculators.
+ */ +package org.apache.ignite.ml.tree.impurity.util; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/main/java/org/apache/ignite/ml/tree/leaf/DecisionTreeLeafBuilder.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/leaf/DecisionTreeLeafBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/leaf/DecisionTreeLeafBuilder.java new file mode 100644 index 0000000..976e30d --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/leaf/DecisionTreeLeafBuilder.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.tree.leaf; + +import org.apache.ignite.ml.dataset.Dataset; +import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.tree.DecisionTreeLeafNode; +import org.apache.ignite.ml.tree.TreeFilter; +import org.apache.ignite.ml.tree.data.DecisionTreeData; +
/**
 * Base interface for decision tree leaf builders.
 */
public interface DecisionTreeLeafBuilder {
    /**
     * Creates a leaf node from the rows of the dataset that the node predicate lets through.
     *
     * @param dataset Dataset.
     * @param pred Node predicate selecting the rows that reach this leaf.
     * @return Leaf node (implementations may return {@code null} when no row is selected).
     */
    DecisionTreeLeafNode createLeafNode(Dataset<EmptyContext, DecisionTreeData> dataset, TreeFilter pred);
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/main/java/org/apache/ignite/ml/tree/leaf/MeanDecisionTreeLeafBuilder.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/leaf/MeanDecisionTreeLeafBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/leaf/MeanDecisionTreeLeafBuilder.java new file mode 100644 index 0000000..2e05215 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/leaf/MeanDecisionTreeLeafBuilder.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.tree.leaf; + +import org.apache.ignite.ml.dataset.Dataset; +import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.tree.DecisionTreeLeafNode; +import org.apache.ignite.ml.tree.TreeFilter; +import org.apache.ignite.ml.tree.data.DecisionTreeData; +
/**
 * Decision tree leaf node builder that chooses mean value as a leaf value.
 */
public class MeanDecisionTreeLeafBuilder implements DecisionTreeLeafBuilder {
    /**
     * {@inheritDoc}
     * <p>
     * The leaf value is the mean label over all rows (of all partitions) accepted by {@code pred}. Returns
     * {@code null} if no row is accepted.
     */
    @Override public DecisionTreeLeafNode createLeafNode(Dataset<EmptyContext, DecisionTreeData> dataset,
        TreeFilter pred) {
        // Static helper keeps the map function from capturing this builder instance.
        double[] meanAndCnt = dataset.compute(part -> partitionMean(part, pred), this::reduce);

        if (meanAndCnt == null)
            return null;

        return new DecisionTreeLeafNode(meanAndCnt[0]);
    }

    /**
     * Computes the mean label and the number of rows accepted by the predicate within one partition.
     *
     * @param part Partition data.
     * @param pred Node predicate.
     * @return Pair {@code {mean, count}}, or {@code null} if the partition has no accepted rows.
     */
    private static double[] partitionMean(DecisionTreeData part, TreeFilter pred) {
        double sum = 0;
        int cnt = 0;

        for (int row = 0; row < part.getFeatures().length; row++) {
            if (!pred.test(part.getFeatures()[row]))
                continue;

            sum += part.getLabels()[row];
            cnt++;
        }

        return cnt == 0 ? null : new double[] {sum / cnt, cnt};
    }

    /**
     * Merges two partial {@code {mean, count}} results into their count-weighted mean; {@code null} stands for
     * "no rows".
     *
     * @param a First partial result.
     * @param b Second partial result.
     * @return Combined result.
     */
    private double[] reduce(double[] a, double[] b) {
        if (a == null)
            return b;

        if (b == null)
            return a;

        double totalCnt = a[1] + b[1];

        return new double[] {(a[0] * a[1] + b[0] * b[1]) / totalCnt, totalCnt};
    }
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/main/java/org/apache/ignite/ml/tree/leaf/MostCommonDecisionTreeLeafBuilder.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/leaf/MostCommonDecisionTreeLeafBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/leaf/MostCommonDecisionTreeLeafBuilder.java new file mode 100644 index 0000000..1e8b941 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/leaf/MostCommonDecisionTreeLeafBuilder.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.tree.leaf; + +import java.util.HashMap; +import java.util.Map; +import org.apache.ignite.ml.dataset.Dataset; +import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.tree.DecisionTreeLeafNode; +import org.apache.ignite.ml.tree.TreeFilter; +import org.apache.ignite.ml.tree.data.DecisionTreeData; +
/**
 * Decision tree leaf node builder that chooses most common value as a leaf node value.
 */
public class MostCommonDecisionTreeLeafBuilder implements DecisionTreeLeafBuilder {
    /**
     * {@inheritDoc}
     * <p>
     * Labels of the rows accepted by {@code pred} are counted over all partitions and the most frequent label becomes
     * the leaf value. Ties are resolved by the (unspecified) iteration order of the count map.
     * <p>
     * Returns {@code null} when the reduction produced no counts at all (no partition with features), the same way
     * the mean-based leaf builder signals "no data". If partitions exist but no row is accepted, the leaf value is
     * {@code 0}.
     */
    @Override public DecisionTreeLeafNode createLeafNode(Dataset<EmptyContext, DecisionTreeData> dataset,
        TreeFilter pred) {
        Map<Double, Integer> cnt = dataset.compute(part -> {
            if (part.getFeatures() == null)
                return null;

            Map<Double, Integer> map = new HashMap<>();

            for (int i = 0; i < part.getFeatures().length; i++) {
                if (pred.test(part.getFeatures()[i]))
                    map.merge(part.getLabels()[i], 1, Integer::sum);
            }

            return map;
        }, this::reduce);

        // Guard against a null reduction result instead of failing with NullPointerException below.
        if (cnt == null)
            return null;

        double bestVal = 0;
        int bestCnt = -1;

        for (Map.Entry<Double, Integer> e : cnt.entrySet()) {
            if (e.getValue() > bestCnt) {
                bestCnt = e.getValue();
                bestVal = e.getKey();
            }
        }

        return new DecisionTreeLeafNode(bestVal);
    }

    /**
     * Merges label counts of two partitions; {@code null} stands for "no data". The first map is updated in place.
     *
     * @param a First label count map (mutated and returned when both maps are present).
     * @param b Second label count map.
     * @return Combined label counts.
     */
    private Map<Double, Integer> reduce(Map<Double, Integer> a, Map<Double, Integer> b) {
        if (a == null)
            return b;

        if (b == null)
            return a;

        b.forEach((lb, c) -> a.merge(lb, c, Integer::sum));

        return a;
    }
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/main/java/org/apache/ignite/ml/tree/leaf/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/leaf/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/leaf/package-info.java new file mode 100644 index 0000000..26ec67d --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/leaf/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Root package for decision trees leaf builders.
+ */ +package org.apache.ignite.ml.tree.leaf; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/main/java/org/apache/ignite/ml/tree/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/package-info.java new file mode 100644 index 0000000..660f3f3 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Root package for decision trees. 
+ */ +package org.apache.ignite.ml.tree; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/main/java/org/apache/ignite/ml/trees/CategoricalRegionInfo.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/CategoricalRegionInfo.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/CategoricalRegionInfo.java deleted file mode 100644 index 3ae474e..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/CategoricalRegionInfo.java +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trees; - -import java.io.Externalizable; -import java.io.IOException; -import java.io.ObjectInput; -import java.io.ObjectOutput; -import java.util.BitSet; - -/** - * Information about categorical region. - */ -public class CategoricalRegionInfo extends RegionInfo implements Externalizable { - /** - * Bitset representing categories of this region. - */ - private BitSet cats; - - /** - * @param impurity Impurity of region. - * @param cats Bitset representing categories of this region. 
- */ - public CategoricalRegionInfo(double impurity, BitSet cats) { - super(impurity); - - this.cats = cats; - } - - /** - * No-op constructor for serialization/deserialization. - */ - public CategoricalRegionInfo() { - // No-op - } - - /** - * Get bitset representing categories of this region. - * - * @return Bitset representing categories of this region. - */ - public BitSet cats() { - return cats; - } - - /** {@inheritDoc} */ - @Override public void writeExternal(ObjectOutput out) throws IOException { - super.writeExternal(out); - out.writeObject(cats); - } - - /** {@inheritDoc} */ - @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { - super.readExternal(in); - cats = (BitSet)in.readObject(); - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/main/java/org/apache/ignite/ml/trees/CategoricalSplitInfo.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/CategoricalSplitInfo.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/CategoricalSplitInfo.java deleted file mode 100644 index 94cb1e8..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/CategoricalSplitInfo.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trees; - -import java.util.BitSet; -import org.apache.ignite.ml.trees.nodes.CategoricalSplitNode; -import org.apache.ignite.ml.trees.nodes.SplitNode; -import org.apache.ignite.ml.trees.trainers.columnbased.vectors.SplitInfo; - -/** - * Information about split of categorical feature. - * - * @param <D> Class representing information of left and right subregions. - */ -public class CategoricalSplitInfo<D extends RegionInfo> extends SplitInfo<D> { - /** Bitset indicating which vectors are assigned to left subregion. */ - private final BitSet bs; - - /** - * @param regionIdx Index of region which is split. - * @param leftData Data of left subregion. - * @param rightData Data of right subregion. - * @param bs Bitset indicating which vectors are assigned to left subregion. - */ - public CategoricalSplitInfo(int regionIdx, D leftData, D rightData, - BitSet bs) { - super(regionIdx, leftData, rightData); - this.bs = bs; - } - - /** {@inheritDoc} */ - @Override public SplitNode createSplitNode(int featureIdx) { - return new CategoricalSplitNode(featureIdx, bs); - } - - /** - * Get bitset indicating which vectors are assigned to left subregion. 
- */ - public BitSet bitSet() { - return bs; - } - - /** {@inheritDoc} */ - @Override public String toString() { - return "CategoricalSplitInfo [" + - "infoGain=" + infoGain + - ", regionIdx=" + regionIdx + - ", leftData=" + leftData + - ", bs=" + bs + - ", rightData=" + rightData + - ']'; - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/main/java/org/apache/ignite/ml/trees/ContinuousRegionInfo.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/ContinuousRegionInfo.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/ContinuousRegionInfo.java deleted file mode 100644 index e98bb72..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/ContinuousRegionInfo.java +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trees; - -import java.io.IOException; -import java.io.ObjectInput; -import java.io.ObjectOutput; - -/** - * Information about region used by continuous features. - */ -public class ContinuousRegionInfo extends RegionInfo { - /** - * Count of samples in this region. 
- */ - private int size; - - /** - * @param impurity Impurity of the region. - * @param size Size of this region - */ - public ContinuousRegionInfo(double impurity, int size) { - super(impurity); - this.size = size; - } - - /** - * No-op constructor for serialization/deserialization. - */ - public ContinuousRegionInfo() { - // No-op - } - - /** - * Get the size of region. - */ - public int getSize() { - return size; - } - - /** {@inheritDoc} */ - @Override public String toString() { - return "ContinuousRegionInfo [" + - "size=" + size + - ']'; - } - - /** {@inheritDoc} */ - @Override public void writeExternal(ObjectOutput out) throws IOException { - super.writeExternal(out); - out.writeInt(size); - } - - /** {@inheritDoc} */ - @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { - super.readExternal(in); - size = in.readInt(); - } -} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/main/java/org/apache/ignite/ml/trees/ContinuousSplitCalculator.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/ContinuousSplitCalculator.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/ContinuousSplitCalculator.java deleted file mode 100644 index 3a0e9da..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/ContinuousSplitCalculator.java +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trees; - -import java.util.stream.DoubleStream; -import org.apache.ignite.ml.trees.trainers.columnbased.vectors.SplitInfo; -import org.apache.ignite.ml.trees.trainers.columnbased.vectors.ContinuousFeatureProcessor; - -/** - * This class is used for calculation of best split by continuous feature. - * - * @param <C> Class in which information about region will be stored. - */ -public interface ContinuousSplitCalculator<C extends ContinuousRegionInfo> { - /** - * Calculate region info 'from scratch'. - * - * @param s Stream of labels in this region. - * @param l Index of sample projection on this feature in array sorted by this projection value and intervals - * bitsets. ({@link ContinuousFeatureProcessor}). - * @return Region info. - */ - C calculateRegionInfo(DoubleStream s, int l); - - /** - * Calculate split info of best split of region given information about this region. - * - * @param sampleIndexes Indexes of samples of this region. - * @param values All values of this feature. - * @param labels All labels of this feature. - * @param regionIdx Index of region being split. - * @param data Information about region being split which can be used for computations. - * @return Information about best split of region with index given by regionIdx. 
- */ - SplitInfo<C> splitRegion(Integer[] sampleIndexes, double[] values, double[] labels, int regionIdx, C data); -} http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/main/java/org/apache/ignite/ml/trees/RegionInfo.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/RegionInfo.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/RegionInfo.java deleted file mode 100644 index 8ec7db3..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/RegionInfo.java +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trees; - -import java.io.Externalizable; -import java.io.IOException; -import java.io.ObjectInput; -import java.io.ObjectOutput; - -/** Class containing information about region. */ -public class RegionInfo implements Externalizable { - /** Impurity in this region. */ - private double impurity; - - /** - * @param impurity Impurity of this region. - */ - public RegionInfo(double impurity) { - this.impurity = impurity; - } - - /** - * No-op constructor for serialization/deserialization. 
- */ - public RegionInfo() { - // No-op - } - - /** - * Get impurity in this region. - * - * @return Impurity of this region. - */ - public double impurity() { - return impurity; - } - - /** {@inheritDoc} */ - @Override public void writeExternal(ObjectOutput out) throws IOException { - out.writeDouble(impurity); - } - - /** {@inheritDoc} */ - @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { - impurity = in.readDouble(); - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/main/java/org/apache/ignite/ml/trees/models/DecisionTreeModel.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/models/DecisionTreeModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/models/DecisionTreeModel.java deleted file mode 100644 index 572e64a..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/models/DecisionTreeModel.java +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.ignite.ml.trees.models; - -import org.apache.ignite.ml.Model; -import org.apache.ignite.ml.math.Vector; -import org.apache.ignite.ml.trees.nodes.DecisionTreeNode; - -/** - * Model for decision tree. - */ -public class DecisionTreeModel implements Model<Vector, Double> { - /** Root node of the decision tree. */ - private final DecisionTreeNode root; - - /** - * Construct decision tree model. - * - * @param root Root of decision tree. - */ - public DecisionTreeModel(DecisionTreeNode root) { - this.root = root; - } - - /** {@inheritDoc} */ - @Override public Double apply(Vector val) { - return root.process(val); - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/main/java/org/apache/ignite/ml/trees/models/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/models/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/models/package-info.java deleted file mode 100644 index ce8418e..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/models/package-info.java +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -/** - * <!-- Package description. --> - * Contains decision tree models. - */ -package org.apache.ignite.ml.trees.models; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/CategoricalSplitNode.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/CategoricalSplitNode.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/CategoricalSplitNode.java deleted file mode 100644 index cae6d4a..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/CategoricalSplitNode.java +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trees.nodes; - -import java.util.BitSet; -import org.apache.ignite.ml.math.Vector; - -/** - * Split node by categorical feature. - */ -public class CategoricalSplitNode extends SplitNode { - /** Bitset specifying which categories belong to left subregion. */ - private final BitSet bs; - - /** - * Construct categorical split node. - * - * @param featureIdx Index of feature by which split is done. 
- * @param bs Bitset specifying which categories go to the left subtree. - */ - public CategoricalSplitNode(int featureIdx, BitSet bs) { - super(featureIdx); - this.bs = bs; - } - - /** {@inheritDoc} */ - @Override public boolean goLeft(Vector v) { - return bs.get((int)v.getX(featureIdx)); - } - - /** {@inheritDoc} */ - @Override public String toString() { - return "CategoricalSplitNode [bs=" + bs + ']'; - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/ContinuousSplitNode.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/ContinuousSplitNode.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/ContinuousSplitNode.java deleted file mode 100644 index 285cfcd..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/ContinuousSplitNode.java +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trees.nodes; - -import org.apache.ignite.ml.math.Vector; - -/** - * Split node representing split of continuous feature. 
- */ -public class ContinuousSplitNode extends SplitNode { - /** Threshold. Values which are less or equal then threshold are assigned to the left subregion. */ - private final double threshold; - - /** - * Construct ContinuousSplitNode by threshold and feature index. - * - * @param threshold Threshold. - * @param featureIdx Feature index. - */ - public ContinuousSplitNode(double threshold, int featureIdx) { - super(featureIdx); - this.threshold = threshold; - } - - /** {@inheritDoc} */ - @Override public boolean goLeft(Vector v) { - return v.getX(featureIdx) <= threshold; - } - - /** Threshold. Values which are less or equal then threshold are assigned to the left subregion. */ - public double threshold() { - return threshold; - } - - /** {@inheritDoc} */ - @Override public String toString() { - return "ContinuousSplitNode [" + - "threshold=" + threshold + - ']'; - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/DecisionTreeNode.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/DecisionTreeNode.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/DecisionTreeNode.java deleted file mode 100644 index d31623d..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/DecisionTreeNode.java +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trees.nodes; - -import org.apache.ignite.ml.math.Vector; - -/** - * Node of decision tree. - */ -public interface DecisionTreeNode { - /** - * Assign the double value to the given vector. - * - * @param v Vector. - * @return Value assigned to the given vector. - */ - double process(Vector v); -} http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/Leaf.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/Leaf.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/Leaf.java deleted file mode 100644 index 79b441f..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/Leaf.java +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trees.nodes; - -import org.apache.ignite.ml.math.Vector; - -/** - * Terminal node of the decision tree. - */ -public class Leaf implements DecisionTreeNode { - /** - * Value in subregion represented by this node. - */ - private final double val; - - /** - * Construct the leaf of decision tree. - * - * @param val Value in subregion represented by this node. - */ - public Leaf(double val) { - this.val = val; - } - - /** - * Return value in subregion represented by this node. - * - * @param v Vector. - * @return Value in subregion represented by this node. - */ - @Override public double process(Vector v) { - return val; - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/SplitNode.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/SplitNode.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/SplitNode.java deleted file mode 100644 index 4c258d1..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/SplitNode.java +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trees.nodes; - -import org.apache.ignite.ml.math.Vector; - -/** - * Node in decision tree representing a split. - */ -public abstract class SplitNode implements DecisionTreeNode { - /** Left subtree. */ - protected DecisionTreeNode l; - - /** Right subtree. */ - protected DecisionTreeNode r; - - /** Feature index. */ - protected final int featureIdx; - - /** - * Constructs SplitNode with a given feature index. - * - * @param featureIdx Feature index. - */ - public SplitNode(int featureIdx) { - this.featureIdx = featureIdx; - } - - /** - * Indicates if the given vector is in left subtree. - * - * @param v Vector - * @return Status of given vector being left subtree. - */ - abstract boolean goLeft(Vector v); - - /** - * Left subtree. - * - * @return Left subtree. - */ - public DecisionTreeNode left() { - return l; - } - - /** - * Right subtree. - * - * @return Right subtree. - */ - public DecisionTreeNode right() { - return r; - } - - /** - * Set the left subtree. - * - * @param n left subtree. - */ - public void setLeft(DecisionTreeNode n) { - l = n; - } - - /** - * Set the right subtree. - * - * @param n right subtree. - */ - public void setRight(DecisionTreeNode n) { - r = n; - } - - /** - * Delegates processing to subtrees. - * - * @param v Vector. - * @return Value assigned to the given vector. 
- */ - @Override public double process(Vector v) { - if (left() != null && goLeft(v)) - return left().process(v); - else - return right().process(v); - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/package-info.java deleted file mode 100644 index d6deb9d..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/package-info.java +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * <!-- Package description. --> - * Contains classes representing decision tree nodes. 
- */ -package org.apache.ignite.ml.trees.nodes; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/main/java/org/apache/ignite/ml/trees/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/package-info.java deleted file mode 100644 index b07ba4a..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/package-info.java +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * <!-- Package description. --> - * Contains decision tree algorithms. 
- */ -package org.apache.ignite.ml.trees; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/BiIndex.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/BiIndex.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/BiIndex.java deleted file mode 100644 index 0d27c8a..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/BiIndex.java +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trees.trainers.columnbased; - -import java.io.Externalizable; -import java.io.IOException; -import java.io.ObjectInput; -import java.io.ObjectOutput; -import org.apache.ignite.cache.affinity.AffinityKeyMapped; - -/** - * Class representing a simple index in 2d matrix in the form (row, col). - */ -public class BiIndex implements Externalizable { - /** Row. */ - private int row; - - /** Column. */ - @AffinityKeyMapped - private int col; - - /** - * No-op constructor for serialization/deserialization. 
- */ - public BiIndex() { - // No-op. - } - - /** - * Construct BiIndex from row and column. - * - * @param row Row. - * @param col Column. - */ - public BiIndex(int row, int col) { - this.row = row; - this.col = col; - } - - /** - * Returns row. - * - * @return Row. - */ - public int row() { - return row; - } - - /** - * Returns column. - * - * @return Column. - */ - public int col() { - return col; - } - - /** {@inheritDoc} */ - @Override public boolean equals(Object o) { - if (this == o) - return true; - if (o == null || getClass() != o.getClass()) - return false; - - BiIndex idx = (BiIndex)o; - - if (row != idx.row) - return false; - return col == idx.col; - } - - /** {@inheritDoc} */ - @Override public int hashCode() { - int res = row; - res = 31 * res + col; - return res; - } - - /** {@inheritDoc} */ - @Override public String toString() { - return "BiIndex [" + - "row=" + row + - ", col=" + col + - ']'; - } - - /** {@inheritDoc} */ - @Override public void writeExternal(ObjectOutput out) throws IOException { - out.writeInt(row); - out.writeInt(col); - } - - /** {@inheritDoc} */ - @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { - row = in.readInt(); - col = in.readInt(); - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/BiIndexedCacheColumnDecisionTreeTrainerInput.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/BiIndexedCacheColumnDecisionTreeTrainerInput.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/BiIndexedCacheColumnDecisionTreeTrainerInput.java deleted file mode 100644 index 04281fb..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/BiIndexedCacheColumnDecisionTreeTrainerInput.java +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the 
Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trees.trainers.columnbased; - -import java.util.Map; -import java.util.stream.DoubleStream; -import java.util.stream.IntStream; -import java.util.stream.Stream; -import org.apache.ignite.Ignite; -import org.apache.ignite.IgniteCache; -import org.apache.ignite.lang.IgniteBiTuple; - -/** - * Adapter for column decision tree trainer for bi-indexed cache. - */ -public class BiIndexedCacheColumnDecisionTreeTrainerInput extends CacheColumnDecisionTreeTrainerInput<BiIndex, Double> { - /** - * Construct an input for {@link ColumnDecisionTreeTrainer}. - * - * @param cache Bi-indexed cache. - * @param catFeaturesInfo Information about categorical feature in the form (feature index -> number of - * categories). - * @param samplesCnt Count of samples. - * @param featuresCnt Count of features. 
- */ - public BiIndexedCacheColumnDecisionTreeTrainerInput(IgniteCache<BiIndex, Double> cache, - Map<Integer, Integer> catFeaturesInfo, int samplesCnt, int featuresCnt) { - super(cache, - () -> IntStream.range(0, samplesCnt).mapToObj(s -> new BiIndex(s, featuresCnt)), - e -> Stream.of(new IgniteBiTuple<>(e.getKey().row(), e.getValue())), - DoubleStream::of, - fIdx -> IntStream.range(0, samplesCnt).mapToObj(s -> new BiIndex(s, fIdx)), - catFeaturesInfo, - featuresCnt, - samplesCnt); - } - - /** {@inheritDoc} */ - @Override public Object affinityKey(int idx, Ignite ignite) { - return idx; - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/9abfee69/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/CacheColumnDecisionTreeTrainerInput.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/CacheColumnDecisionTreeTrainerInput.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/CacheColumnDecisionTreeTrainerInput.java deleted file mode 100644 index 40927b7..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/CacheColumnDecisionTreeTrainerInput.java +++ /dev/null @@ -1,141 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trees.trainers.columnbased; - -import java.util.Map; -import java.util.stream.Collectors; -import java.util.stream.DoubleStream; -import java.util.stream.Stream; -import javax.cache.Cache; -import org.apache.ignite.Ignite; -import org.apache.ignite.IgniteCache; -import org.apache.ignite.Ignition; -import org.apache.ignite.internal.processors.cache.CacheEntryImpl; -import org.apache.ignite.lang.IgniteBiTuple; -import org.apache.ignite.ml.math.functions.IgniteFunction; -import org.apache.ignite.ml.math.functions.IgniteSupplier; - -/** - * Adapter of a given cache to {@link CacheColumnDecisionTreeTrainerInput} - * - * @param <K> Class of keys of the cache. - * @param <V> Class of values of the cache. - */ -public abstract class CacheColumnDecisionTreeTrainerInput<K, V> implements ColumnDecisionTreeTrainerInput { - /** Supplier of labels key. */ - private final IgniteSupplier<Stream<K>> labelsKeys; - - /** Count of features. */ - private final int featuresCnt; - - /** Function which maps feature index to Stream of keys corresponding to this feature index. */ - private final IgniteFunction<Integer, Stream<K>> keyMapper; - - /** Information about which features are categorical in form of feature index -> number of categories. */ - private final Map<Integer, Integer> catFeaturesInfo; - - /** Cache name. */ - private final String cacheName; - - /** Count of samples. */ - private final int samplesCnt; - - /** Function used for mapping cache values to stream of tuples. */ - private final IgniteFunction<Cache.Entry<K, V>, Stream<IgniteBiTuple<Integer, Double>>> valuesMapper; - - /** - * Function which map value of entry with label key to DoubleStream. - * Look at {@code CacheColumnDecisionTreeTrainerInput::labels} for understanding how {@code labelsKeys} and - * {@code labelsMapper} interact. 
- */ - private final IgniteFunction<V, DoubleStream> labelsMapper; - - /** - * Constructs input for {@link ColumnDecisionTreeTrainer}. - * - * @param c Cache. - * @param valuesMapper Function for mapping cache entry to stream used by {@link ColumnDecisionTreeTrainer}. - * @param labelsMapper Function used for mapping cache value to labels array. - * @param keyMapper Function used for mapping feature index to the cache key. - * @param catFeaturesInfo Information about which features are categorical in form of feature index -> number of - * categories. - * @param featuresCnt Count of features. - * @param samplesCnt Count of samples. - */ - // TODO: IGNITE-5724 think about boxing/unboxing - public CacheColumnDecisionTreeTrainerInput(IgniteCache<K, V> c, - IgniteSupplier<Stream<K>> labelsKeys, - IgniteFunction<Cache.Entry<K, V>, Stream<IgniteBiTuple<Integer, Double>>> valuesMapper, - IgniteFunction<V, DoubleStream> labelsMapper, - IgniteFunction<Integer, Stream<K>> keyMapper, - Map<Integer, Integer> catFeaturesInfo, - int featuresCnt, int samplesCnt) { - - cacheName = c.getName(); - this.labelsKeys = labelsKeys; - this.valuesMapper = valuesMapper; - this.labelsMapper = labelsMapper; - this.keyMapper = keyMapper; - this.catFeaturesInfo = catFeaturesInfo; - this.samplesCnt = samplesCnt; - this.featuresCnt = featuresCnt; - } - - /** {@inheritDoc} */ - @Override public Stream<IgniteBiTuple<Integer, Double>> values(int idx) { - return cache(Ignition.localIgnite()).getAll(keyMapper.apply(idx).collect(Collectors.toSet())). - entrySet(). - stream(). 
- flatMap(ent -> valuesMapper.apply(new CacheEntryImpl<>(ent.getKey(), ent.getValue()))); - } - - /** {@inheritDoc} */ - @Override public double[] labels(Ignite ignite) { - return labelsKeys.get().map(k -> get(k, ignite)).flatMapToDouble(labelsMapper).toArray(); - } - - /** {@inheritDoc} */ - @Override public Map<Integer, Integer> catFeaturesInfo() { - return catFeaturesInfo; - } - - /** {@inheritDoc} */ - @Override public int featuresCount() { - return featuresCnt; - } - - /** {@inheritDoc} */ - @Override public Object affinityKey(int idx, Ignite ignite) { - return ignite.affinity(cacheName).affinityKey(keyMapper.apply(idx)); - } - - /** */ - private V get(K k, Ignite ignite) { - V res = cache(ignite).localPeek(k); - - if (res == null) - res = cache(ignite).get(k); - - return res; - } - - /** */ - private IgniteCache<K, V> cache(Ignite ignite) { - return ignite.getOrCreateCache(cacheName); - } -}