IGNITE-5012 Implement ordinary least squares (OLS) linear regression.
Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/934f6ac2 Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/934f6ac2 Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/934f6ac2 Branch: refs/heads/ignite-5075-cacheStart Commit: 934f6ac22c04f652815f79a9238ea72b9111a7e8 Parents: 156ec53 Author: Artem Malykh <amal...@gridgain.com> Authored: Wed May 3 20:05:18 2017 +0300 Committer: Anton Vinogradov <a...@apache.org> Committed: Wed May 3 20:05:18 2017 +0300 ---------------------------------------------------------------------- .../org/apache/ignite/ml/math/Precision.java | 588 ++++++++++++++ .../java/org/apache/ignite/ml/math/Tracer.java | 22 +- .../decompositions/CholeskyDecomposition.java | 6 +- .../math/decompositions/EigenDecomposition.java | 6 +- .../ml/math/decompositions/LUDecomposition.java | 9 +- .../ml/math/decompositions/QRDecomposition.java | 70 +- .../SingularValueDecomposition.java | 5 +- .../math/exceptions/CardinalityException.java | 6 +- .../exceptions/InsufficientDataException.java | 44 + .../exceptions/MathArithmeticException.java | 47 ++ .../MathIllegalArgumentException.java | 37 + .../math/exceptions/MathRuntimeException.java | 47 ++ .../ml/math/exceptions/NoDataException.java | 45 + .../NonPositiveDefiniteMatrixException.java | 8 +- .../exceptions/NonSquareMatrixException.java | 33 + .../math/exceptions/NullArgumentException.java | 27 + .../exceptions/SingularMatrixException.java | 9 +- .../ignite/ml/math/functions/Functions.java | 5 + .../apache/ignite/ml/math/util/MatrixUtil.java | 121 +++ .../AbstractMultipleLinearRegression.java | 358 ++++++++ .../regressions/MultipleLinearRegression.java | 71 ++ .../OLSMultipleLinearRegression.java | 272 +++++++ .../regressions/RegressionsErrorMessages.java | 28 + .../ignite/ml/regressions/package-info.java | 22 + .../java/org/apache/ignite/ml/TestUtils.java | 248 ++++++ .../apache/ignite/ml/math/ExternalizeTest.java | 1 + 
.../ignite/ml/math/MathImplLocalTestSuite.java | 7 +- .../CholeskyDecompositionTest.java | 6 +- .../decompositions/LUDecompositionTest.java | 6 +- .../decompositions/QRDecompositionTest.java | 6 +- .../SingularValueDecompositionTest.java | 6 +- .../AbstractMultipleLinearRegressionTest.java | 164 ++++ .../OLSMultipleLinearRegressionTest.java | 812 +++++++++++++++++++ 33 files changed, 3100 insertions(+), 42 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/934f6ac2/modules/ml/src/main/java/org/apache/ignite/ml/math/Precision.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/Precision.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/Precision.java new file mode 100644 index 0000000..830644c --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/Precision.java @@ -0,0 +1,588 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.math; + +import java.math.BigDecimal; +import org.apache.ignite.ml.math.exceptions.MathArithmeticException; +import org.apache.ignite.ml.math.exceptions.MathIllegalArgumentException; + +/** + * This class is based on the corresponding class from the Apache Commons Math library. + * Utilities for comparing numbers. + */ +public class Precision { + /** + * <p> + * Largest double-precision floating-point number such that + * {@code 1 + EPSILON} is numerically equal to 1. This value is an upper + * bound on the relative error due to rounding real numbers to double + * precision floating-point numbers. + * </p> + * <p> + * In IEEE 754 arithmetic, this is 2<sup>-53</sup>. + * </p> + * + * @see <a href="http://en.wikipedia.org/wiki/Machine_epsilon">Machine epsilon</a> + */ + public static final double EPSILON; + + /** + * Safe minimum, such that {@code 1 / SAFE_MIN} does not overflow. + * <br/> + * In IEEE 754 arithmetic, this is also the smallest normalized + * number 2<sup>-1022</sup>. + */ + public static final double SAFE_MIN; + + /** Exponent offset in IEEE754 representation. */ + private static final long EXPONENT_OFFSET = 1023L; + + /** Offset to order signed double numbers lexicographically. */ + private static final long SGN_MASK = 0x8000000000000000L; + /** Offset to order signed float numbers lexicographically. */ + private static final int SGN_MASK_FLOAT = 0x80000000; + /** Positive zero. */ + private static final double POSITIVE_ZERO = 0d; + /** Positive zero bits. */ + private static final long POSITIVE_ZERO_DOUBLE_BITS = Double.doubleToRawLongBits(+0.0); + /** Negative zero bits. */ + private static final long NEGATIVE_ZERO_DOUBLE_BITS = Double.doubleToRawLongBits(-0.0); + /** Positive zero bits. */ + private static final int POSITIVE_ZERO_FLOAT_BITS = Float.floatToRawIntBits(+0.0f); + /** Negative zero bits. 
*/ + private static final int NEGATIVE_ZERO_FLOAT_BITS = Float.floatToRawIntBits(-0.0f); + /** */ + private static final String INVALID_ROUNDING_METHOD = "invalid rounding method {0}, " + + "valid methods: {1} ({2}), {3} ({4}), {5} ({6}), {7} ({8}), {9} ({10}), {11} ({12}), {13} ({14}), {15} ({16})"; + + static { + /* + * This was previously expressed as = 0x1.0p-53; + * However, OpenJDK (Sparc Solaris) cannot handle such small + * constants: MATH-721 + */ + EPSILON = Double.longBitsToDouble((EXPONENT_OFFSET - 53L) << 52); + + /* + * This was previously expressed as = 0x1.0p-1022; + * However, OpenJDK (Sparc Solaris) cannot handle such small + * constants: MATH-721 + */ + SAFE_MIN = Double.longBitsToDouble((EXPONENT_OFFSET - 1022L) << 52); + } + + /** + * Private constructor. + */ + private Precision() { + } + + /** + * Compares two numbers given some amount of allowed error. + * + * @param x the first number + * @param y the second number + * @param eps the amount of error to allow when checking for equality + * @return <ul><li>0 if {@link #equals(double, double, double) equals(x, y, eps)}</li> <li>< 0 if !{@link + * #equals(double, double, double) equals(x, y, eps)} && x < y</li> <li>> 0 if !{@link #equals(double, + * double, double) equals(x, y, eps)} && x > y or either argument is NaN</li></ul> + */ + public static int compareTo(double x, double y, double eps) { + if (equals(x, y, eps)) + return 0; + else if (x < y) + return -1; + return 1; + } + + /** + * Compares two numbers given some amount of allowed error. + * Two float numbers are considered equal if there are {@code (maxUlps - 1)} + * (or fewer) floating point numbers between them, i.e. two adjacent floating + * point numbers are considered equal. + * Adapted from <a + * href="http://randomascii.wordpress.com/2012/02/25/comparing-floating-point-numbers-2012-edition/"> + * Bruce Dawson</a>. Returns {@code false} if either of the arguments is NaN. 
+ * + * @param x first value + * @param y second value + * @param maxUlps {@code (maxUlps - 1)} is the number of floating point values between {@code x} and {@code y}. + * @return <ul><li>0 if {@link #equals(double, double, int) equals(x, y, maxUlps)}</li> <li>< 0 if !{@link + * #equals(double, double, int) equals(x, y, maxUlps)} && x < y</li> <li>> 0 if !{@link + * #equals(double, double, int) equals(x, y, maxUlps)} && x > y or either argument is NaN</li></ul> + */ + public static int compareTo(final double x, final double y, final int maxUlps) { + if (equals(x, y, maxUlps)) + return 0; + else if (x < y) + return -1; + return 1; + } + + /** + * Returns true iff they are equal as defined by + * {@link #equals(float, float, int) equals(x, y, 1)}. + * + * @param x first value + * @param y second value + * @return {@code true} if the values are equal. + */ + public static boolean equals(float x, float y) { + return equals(x, y, 1); + } + + /** + * Returns true if both arguments are NaN or they are + * equal as defined by {@link #equals(float, float) equals(x, y, 1)}. + * + * @param x first value + * @param y second value + * @return {@code true} if the values are equal or both are NaN. + * @since 2.2 + */ + public static boolean equalsIncludingNaN(float x, float y) { + return (x != x || y != y) ? !(x != x ^ y != y) : equals(x, y, 1); + } + + /** + * Returns true if the arguments are equal or within the range of allowed + * error (inclusive). Returns {@code false} if either of the arguments + * is NaN. + * + * @param x first value + * @param y second value + * @param eps the amount of absolute error to allow. + * @return {@code true} if the values are equal or within range of each other. + * @since 2.2 + */ + public static boolean equals(float x, float y, float eps) { + return equals(x, y, 1) || Math.abs(y - x) <= eps; + } + + /** + * Returns true if the arguments are both NaN, are equal, or are within the range + * of allowed error (inclusive). 
+ * + * @param x first value + * @param y second value + * @param eps the amount of absolute error to allow. + * @return {@code true} if the values are equal or within range of each other, or both are NaN. + * @since 2.2 + */ + public static boolean equalsIncludingNaN(float x, float y, float eps) { + return equalsIncludingNaN(x, y) || (Math.abs(y - x) <= eps); + } + + /** + * Returns true if the arguments are equal or within the range of allowed + * error (inclusive). + * Two float numbers are considered equal if there are {@code (maxUlps - 1)} + * (or fewer) floating point numbers between them, i.e. two adjacent floating + * point numbers are considered equal. + * Adapted from <a + * href="http://randomascii.wordpress.com/2012/02/25/comparing-floating-point-numbers-2012-edition/"> + * Bruce Dawson</a>. Returns {@code false} if either of the arguments is NaN. + * + * @param x first value + * @param y second value + * @param maxUlps {@code (maxUlps - 1)} is the number of floating point values between {@code x} and {@code y}. + * @return {@code true} if there are fewer than {@code maxUlps} floating point values between {@code x} and {@code + * y}. 
+ * @since 2.2 + */ + public static boolean equals(final float x, final float y, final int maxUlps) { + + final int xInt = Float.floatToRawIntBits(x); + final int yInt = Float.floatToRawIntBits(y); + + final boolean isEqual; + if (((xInt ^ yInt) & SGN_MASK_FLOAT) == 0) { + // number have same sign, there is no risk of overflow + isEqual = Math.abs(xInt - yInt) <= maxUlps; + } + else { + // number have opposite signs, take care of overflow + final int deltaPlus; + final int deltaMinus; + if (xInt < yInt) { + deltaPlus = yInt - POSITIVE_ZERO_FLOAT_BITS; + deltaMinus = xInt - NEGATIVE_ZERO_FLOAT_BITS; + } + else { + deltaPlus = xInt - POSITIVE_ZERO_FLOAT_BITS; + deltaMinus = yInt - NEGATIVE_ZERO_FLOAT_BITS; + } + + if (deltaPlus > maxUlps) + isEqual = false; + else + isEqual = deltaMinus <= (maxUlps - deltaPlus); + + } + + return isEqual && !Float.isNaN(x) && !Float.isNaN(y); + + } + + /** + * Returns true if the arguments are both NaN or if they are equal as defined + * by {@link #equals(float, float, int) equals(x, y, maxUlps)}. + * + * @param x first value + * @param y second value + * @param maxUlps {@code (maxUlps - 1)} is the number of floating point values between {@code x} and {@code y}. + * @return {@code true} if both arguments are NaN or if there are less than {@code maxUlps} floating point values + * between {@code x} and {@code y}. + * @since 2.2 + */ + public static boolean equalsIncludingNaN(float x, float y, int maxUlps) { + return (x != x || y != y) ? !(x != x ^ y != y) : equals(x, y, maxUlps); + } + + /** + * Returns true iff they are equal as defined by + * {@link #equals(double, double, int) equals(x, y, 1)}. + * + * @param x first value + * @param y second value + * @return {@code true} if the values are equal. + */ + public static boolean equals(double x, double y) { + return equals(x, y, 1); + } + + /** + * Returns true if the arguments are both NaN or they are + * equal as defined by {@link #equals(double, double) equals(x, y, 1)}. 
+ * + * @param x first value + * @param y second value + * @return {@code true} if the values are equal or both are NaN. + * @since 2.2 + */ + public static boolean equalsIncludingNaN(double x, double y) { + return (x != x || y != y) ? !(x != x ^ y != y) : equals(x, y, 1); + } + + /** + * Returns {@code true} if there is no double value strictly between the + * arguments or the difference between them is within the range of allowed + * error (inclusive). Returns {@code false} if either of the arguments + * is NaN. + * + * @param x First value. + * @param y Second value. + * @param eps Amount of allowed absolute error. + * @return {@code true} if the values are two adjacent floating point numbers or they are within range of each + * other. + */ + public static boolean equals(double x, double y, double eps) { + return equals(x, y, 1) || Math.abs(y - x) <= eps; + } + + /** + * Returns {@code true} if there is no double value strictly between the + * arguments or the relative difference between them is less than or equal + * to the given tolerance. Returns {@code false} if either of the arguments + * is NaN. + * + * @param x First value. + * @param y Second value. + * @param eps Amount of allowed relative error. + * @return {@code true} if the values are two adjacent floating point numbers or they are within range of each + * other. + * @since 3.1 + */ + public static boolean equalsWithRelativeTolerance(double x, double y, double eps) { + if (equals(x, y, 1)) + return true; + + final double absMax = Math.max(Math.abs(x), Math.abs(y)); + final double relativeDifference = Math.abs((x - y) / absMax); + + return relativeDifference <= eps; + } + + /** + * Returns true if the arguments are both NaN, are equal or are within the range + * of allowed error (inclusive). + * + * @param x first value + * @param y second value + * @param eps the amount of absolute error to allow. + * @return {@code true} if the values are equal or within range of each other, or both are NaN. 
+ * @since 2.2 + */ + public static boolean equalsIncludingNaN(double x, double y, double eps) { + return equalsIncludingNaN(x, y) || (Math.abs(y - x) <= eps); + } + + /** + * Returns true if the arguments are equal or within the range of allowed + * error (inclusive). + * <p> + * Two double numbers are considered equal if there are {@code (maxUlps - 1)} + * (or fewer) floating point numbers between them, i.e. two adjacent + * floating point numbers are considered equal. + * </p> + * <p> + * Adapted from <a + * href="http://randomascii.wordpress.com/2012/02/25/comparing-floating-point-numbers-2012-edition/"> + * Bruce Dawson</a>. Returns {@code false} if either of the arguments is NaN. + * </p> + * + * @param x first value + * @param y second value + * @param maxUlps {@code (maxUlps - 1)} is the number of floating point values between {@code x} and {@code y}. + * @return {@code true} if there are fewer than {@code maxUlps} floating point values between {@code x} and {@code + * y}. + */ + public static boolean equals(final double x, final double y, final int maxUlps) { + + final long xInt = Double.doubleToRawLongBits(x); + final long yInt = Double.doubleToRawLongBits(y); + + final boolean isEqual; + if (((xInt ^ yInt) & SGN_MASK) == 0L) { + // numbers have the same sign, there is no risk of overflow + isEqual = Math.abs(xInt - yInt) <= maxUlps; + } + else { + // numbers have opposite signs, take care of overflow + final long deltaPlus; + final long deltaMinus; + if (xInt < yInt) { + deltaPlus = yInt - POSITIVE_ZERO_DOUBLE_BITS; + deltaMinus = xInt - NEGATIVE_ZERO_DOUBLE_BITS; + } + else { + deltaPlus = xInt - POSITIVE_ZERO_DOUBLE_BITS; + deltaMinus = yInt - NEGATIVE_ZERO_DOUBLE_BITS; + } + + if (deltaPlus > maxUlps) + isEqual = false; + else + isEqual = deltaMinus <= (maxUlps - deltaPlus); + + } + + return isEqual && !Double.isNaN(x) && !Double.isNaN(y); + + } + + /** + * Returns true if both arguments are NaN or if they are equal as defined + * by {@link #equals(double, 
double, int) equals(x, y, maxUlps)}. + * + * @param x first value + * @param y second value + * @param maxUlps {@code (maxUlps - 1)} is the number of floating point values between {@code x} and {@code y}. + * @return {@code true} if both arguments are NaN or if there are less than {@code maxUlps} floating point values + * between {@code x} and {@code y}. + * @since 2.2 + */ + public static boolean equalsIncludingNaN(double x, double y, int maxUlps) { + return (x != x || y != y) ? !(x != x ^ y != y) : equals(x, y, maxUlps); + } + + /** + * Rounds the given value to the specified number of decimal places. + * The value is rounded using the {@link BigDecimal#ROUND_HALF_UP} method. + * + * @param x Value to round. + * @param scale Number of digits to the right of the decimal point. + * @return the rounded value. + * @since 1.1 (previously in {@code MathUtils}, moved as of version 3.0) + */ + public static double round(double x, int scale) { + return round(x, scale, BigDecimal.ROUND_HALF_UP); + } + + /** + * Rounds the given value to the specified number of decimal places. + * The value is rounded using the given method which is any method defined + * in {@link BigDecimal}. + * If {@code x} is infinite or {@code NaN}, then the value of {@code x} is + * returned unchanged, regardless of the other parameters. + * + * @param x Value to round. + * @param scale Number of digits to the right of the decimal point. + * @param roundingMtd Rounding method as defined in {@link BigDecimal}. + * @return the rounded value. + * @throws ArithmeticException if {@code roundingMethod == ROUND_UNNECESSARY} and the specified scaling operation + * would require rounding. + * @throws IllegalArgumentException if {@code roundingMethod} does not represent a valid rounding mode. 
+ * @since 1.1 (previously in {@code MathUtils}, moved as of version 3.0) + */ + public static double round(double x, int scale, int roundingMtd) { + try { + final double rounded = (new BigDecimal(Double.toString(x)) + .setScale(scale, roundingMtd)) + .doubleValue(); + // MATH-1089: negative values rounded to zero should result in negative zero + return rounded == POSITIVE_ZERO ? POSITIVE_ZERO * x : rounded; + } + catch (NumberFormatException ex) { + if (Double.isInfinite(x)) + return x; + else + return Double.NaN; + } + } + + /** + * Rounds the given value to the specified number of decimal places. + * The value is rounded using the {@link BigDecimal#ROUND_HALF_UP} method. + * + * @param x Value to round. + * @param scale Number of digits to the right of the decimal point. + * @return the rounded value. + * @since 1.1 (previously in {@code MathUtils}, moved as of version 3.0) + */ + public static float round(float x, int scale) { + return round(x, scale, BigDecimal.ROUND_HALF_UP); + } + + /** + * Rounds the given value to the specified number of decimal places. + * The value is rounded using the given method which is any method defined + * in {@link BigDecimal}. + * + * @param x Value to round. + * @param scale Number of digits to the right of the decimal point. + * @param roundingMtd Rounding method as defined in {@link BigDecimal}. + * @return the rounded value. + * @throws MathArithmeticException if an exact operation is required but result is not exact + * @throws MathIllegalArgumentException if {@code roundingMethod} is not a valid rounding method. 
+ * @since 1.1 (previously in {@code MathUtils}, moved as of version 3.0) + */ + public static float round(float x, int scale, int roundingMtd) + throws MathArithmeticException, MathIllegalArgumentException { + final float sign = Math.copySign(1f, x); + final float factor = (float)Math.pow(10.0f, scale) * sign; + return (float)roundUnscaled(x * factor, sign, roundingMtd) / factor; + } + + /** + * Rounds the given non-negative value to the "nearest" integer. Nearest is + * determined by the rounding method specified. Rounding methods are defined + * in {@link BigDecimal}. + * + * @param unscaled Value to round. + * @param sign Sign of the original, scaled value. + * @param roundingMtd Rounding method, as defined in {@link BigDecimal}. + * @return the rounded value. + * @throws MathArithmeticException if an exact operation is required but result is not exact + * @throws MathIllegalArgumentException if {@code roundingMethod} is not a valid rounding method. + * @since 1.1 (previously in {@code MathUtils}, moved as of version 3.0) + */ + private static double roundUnscaled(double unscaled, double sign, int roundingMtd) + throws MathArithmeticException, MathIllegalArgumentException { + switch (roundingMtd) { + case BigDecimal.ROUND_CEILING: + if (sign == -1) + unscaled = Math.floor(Math.nextAfter(unscaled, Double.NEGATIVE_INFINITY)); + else + unscaled = Math.ceil(Math.nextAfter(unscaled, Double.POSITIVE_INFINITY)); + break; + case BigDecimal.ROUND_DOWN: + unscaled = Math.floor(Math.nextAfter(unscaled, Double.NEGATIVE_INFINITY)); + break; + case BigDecimal.ROUND_FLOOR: + if (sign == -1) + unscaled = Math.ceil(Math.nextAfter(unscaled, Double.POSITIVE_INFINITY)); + else + unscaled = Math.floor(Math.nextAfter(unscaled, Double.NEGATIVE_INFINITY)); + break; + case BigDecimal.ROUND_HALF_DOWN: { + unscaled = Math.nextAfter(unscaled, Double.NEGATIVE_INFINITY); + double fraction = unscaled - Math.floor(unscaled); + if (fraction > 0.5) + unscaled = Math.ceil(unscaled); + else + 
unscaled = Math.floor(unscaled); + break; + } + case BigDecimal.ROUND_HALF_EVEN: { + double fraction = unscaled - Math.floor(unscaled); + if (fraction > 0.5) + unscaled = Math.ceil(unscaled); + else if (fraction < 0.5) + unscaled = Math.floor(unscaled); + else { + // The following equality test is intentional and needed for rounding purposes + if (Math.floor(unscaled) / 2.0 == Math.floor(Math.floor(unscaled) / 2.0)) { // even + unscaled = Math.floor(unscaled); + } + else { // odd + unscaled = Math.ceil(unscaled); + } + } + break; + } + case BigDecimal.ROUND_HALF_UP: { + unscaled = Math.nextAfter(unscaled, Double.POSITIVE_INFINITY); + double fraction = unscaled - Math.floor(unscaled); + if (fraction >= 0.5) + unscaled = Math.ceil(unscaled); + else + unscaled = Math.floor(unscaled); + break; + } + case BigDecimal.ROUND_UNNECESSARY: + if (unscaled != Math.floor(unscaled)) + throw new MathArithmeticException(); + break; + case BigDecimal.ROUND_UP: + // do not round if the discarded fraction is equal to zero + if (unscaled != Math.floor(unscaled)) + unscaled = Math.ceil(Math.nextAfter(unscaled, Double.POSITIVE_INFINITY)); + break; + default: + throw new MathIllegalArgumentException(INVALID_ROUNDING_METHOD, + roundingMtd, + "ROUND_CEILING", BigDecimal.ROUND_CEILING, + "ROUND_DOWN", BigDecimal.ROUND_DOWN, + "ROUND_FLOOR", BigDecimal.ROUND_FLOOR, + "ROUND_HALF_DOWN", BigDecimal.ROUND_HALF_DOWN, + "ROUND_HALF_EVEN", BigDecimal.ROUND_HALF_EVEN, + "ROUND_HALF_UP", BigDecimal.ROUND_HALF_UP, + "ROUND_UNNECESSARY", BigDecimal.ROUND_UNNECESSARY, + "ROUND_UP", BigDecimal.ROUND_UP); + } + return unscaled; + } + + /** + * Computes a number {@code delta} close to {@code originalDelta} with + * the property that <pre><code> + * x + delta - x + * </code></pre> + * is exactly machine-representable. + * This is useful when computing numerical derivatives, in order to reduce + * roundoff errors. + * + * @param x Value. + * @param originalDelta Offset value. 
+ * @return a number {@code delta} so that {@code x + delta} and {@code x} differ by a representable floating number. + */ + public static double representableDelta(double x, double originalDelta) { + return x + originalDelta - x; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/934f6ac2/modules/ml/src/main/java/org/apache/ignite/ml/math/Tracer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/Tracer.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/Tracer.java index d334575..d343ce8 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/math/Tracer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/Tracer.java @@ -58,9 +58,9 @@ public class Tracer { return new ColorMapper() { /** {@inheritDoc} */ @Override public Color apply(Double d) { - int r = (int) Math.round(255 * d); + int r = (int)Math.round(255 * d); int g = 0; - int b = (int) Math.round(255 * (1 - d)); + int b = (int)Math.round(255 * (1 - d)); return new Color(r, g, b); } @@ -195,8 +195,8 @@ public class Tracer { /** * Saves given vector as CSV file. * - * @param vec Vector to save. - * @param fmt Format to use. + * @param vec Vector to save. + * @param fmt Format to use. * @param filePath Path of the file to save to. */ public static void saveAsCsv(Vector vec, String fmt, String filePath) throws IOException { @@ -208,8 +208,8 @@ public class Tracer { /** * Saves given matrix as CSV file. * - * @param mtx Matrix to save. - * @param fmt Format to use. + * @param mtx Matrix to save. + * @param fmt Format to use. * @param filePath Path of the file to save to. */ public static void saveAsCsv(Matrix mtx, String fmt, String filePath) throws IOException { @@ -232,7 +232,7 @@ public class Tracer { * Shows given matrix in the browser with D3-based visualization. * * @param mtx Matrix to show. - * @param cm Optional color mapper. If not provided - red-to-blue (R_B) mapper will be used. 
+ * @param cm Optional color mapper. If not provided - red-to-blue (R_B) mapper will be used. * @throws IOException Thrown in case of any errors. */ public static void showHtml(Matrix mtx, ColorMapper cm) throws IOException { @@ -263,7 +263,7 @@ public class Tracer { } /** - * @param d Value of {@link Matrix} or {@link Vector} element. + * @param d Value of {@link Matrix} or {@link Vector} element. * @param clr {@link Color} to paint. * @return JSON representation for given value and color. */ @@ -280,7 +280,7 @@ public class Tracer { * Shows given vector in the browser with D3-based visualization. * * @param vec Vector to show. - * @param cm Optional color mapper. If not provided - red-to-blue (R_B) mapper will be used. + * @param cm Optional color mapper. If not provided - red-to-blue (R_B) mapper will be used. * @throws IOException Thrown in case of any errors. */ public static void showHtml(Vector vec, ColorMapper cm) throws IOException { @@ -366,7 +366,7 @@ public class Tracer { * Gets JavaScript array presentation of this vector. * * @param vec Vector to JavaScript-ify. - * @param cm Color mapper to user. + * @param cm Color mapper to user. */ private static String mkJsArrayString(Vector vec, ColorMapper cm) { boolean first = true; @@ -393,7 +393,7 @@ public class Tracer { * Gets JavaScript array presentation of this vector. * * @param mtx Matrix to JavaScript-ify. - * @param cm Color mapper to user. + * @param cm Color mapper to user. 
*/ private static String mkJsArrayString(Matrix mtx, ColorMapper cm) { boolean first = true; http://git-wip-us.apache.org/repos/asf/ignite/blob/934f6ac2/modules/ml/src/main/java/org/apache/ignite/ml/math/decompositions/CholeskyDecomposition.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/decompositions/CholeskyDecomposition.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/decompositions/CholeskyDecomposition.java index 6053e1c..84028fe 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/math/decompositions/CholeskyDecomposition.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/decompositions/CholeskyDecomposition.java @@ -17,12 +17,16 @@ package org.apache.ignite.ml.math.decompositions; +import org.apache.ignite.ml.math.Destroyable; import org.apache.ignite.ml.math.Matrix; import org.apache.ignite.ml.math.Vector; import org.apache.ignite.ml.math.exceptions.CardinalityException; import org.apache.ignite.ml.math.exceptions.NonPositiveDefiniteMatrixException; import org.apache.ignite.ml.math.exceptions.NonSymmetricMatrixException; +import static org.apache.ignite.ml.math.util.MatrixUtil.like; +import static org.apache.ignite.ml.math.util.MatrixUtil.likeVector; + /** * Calculates the Cholesky decomposition of a matrix. * <p> @@ -31,7 +35,7 @@ import org.apache.ignite.ml.math.exceptions.NonSymmetricMatrixException; * @see <a href="http://mathworld.wolfram.com/CholeskyDecomposition.html">MathWorld</a> * @see <a href="http://en.wikipedia.org/wiki/Cholesky_decomposition">Wikipedia</a> */ -public class CholeskyDecomposition extends DecompositionSupport { +public class CholeskyDecomposition implements Destroyable { /** * Default threshold above which off-diagonal elements are considered too different * and matrix not symmetric. 
http://git-wip-us.apache.org/repos/asf/ignite/blob/934f6ac2/modules/ml/src/main/java/org/apache/ignite/ml/math/decompositions/EigenDecomposition.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/decompositions/EigenDecomposition.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/decompositions/EigenDecomposition.java index 698cbef..d0e91a5 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/math/decompositions/EigenDecomposition.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/decompositions/EigenDecomposition.java @@ -17,17 +17,21 @@ package org.apache.ignite.ml.math.decompositions; +import org.apache.ignite.ml.math.Destroyable; import org.apache.ignite.ml.math.Matrix; import org.apache.ignite.ml.math.Vector; import org.apache.ignite.ml.math.functions.Functions; +import static org.apache.ignite.ml.math.util.MatrixUtil.like; +import static org.apache.ignite.ml.math.util.MatrixUtil.likeVector; + /** * This class provides EigenDecomposition of given matrix. The class is based on * class with similar name from <a href="http://mahout.apache.org/">Apache Mahout</a> library. * * @see <a href=http://mathworld.wolfram.com/EigenDecomposition.html>MathWorld</a> */ -public class EigenDecomposition extends DecompositionSupport { +public class EigenDecomposition implements Destroyable { /** Row and column dimension (square matrix). 
*/ private final int n; http://git-wip-us.apache.org/repos/asf/ignite/blob/934f6ac2/modules/ml/src/main/java/org/apache/ignite/ml/math/decompositions/LUDecomposition.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/decompositions/LUDecomposition.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/decompositions/LUDecomposition.java index 02a3123..4c388b3 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/math/decompositions/LUDecomposition.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/decompositions/LUDecomposition.java @@ -17,11 +17,16 @@ package org.apache.ignite.ml.math.decompositions; +import org.apache.ignite.ml.math.Destroyable; import org.apache.ignite.ml.math.Matrix; import org.apache.ignite.ml.math.Vector; import org.apache.ignite.ml.math.exceptions.CardinalityException; import org.apache.ignite.ml.math.exceptions.SingularMatrixException; +import static org.apache.ignite.ml.math.util.MatrixUtil.copy; +import static org.apache.ignite.ml.math.util.MatrixUtil.like; +import static org.apache.ignite.ml.math.util.MatrixUtil.likeVector; + /** * Calculates the LU-decomposition of a square matrix. * <p> @@ -29,8 +34,10 @@ import org.apache.ignite.ml.math.exceptions.SingularMatrixException; * * @see <a href="http://mathworld.wolfram.com/LUDecomposition.html">MathWorld</a> * @see <a href="http://en.wikipedia.org/wiki/LU_decomposition">Wikipedia</a> + * + * TODO: Maybe we should make this class (and other decompositions) Externalizable. */ -public class LUDecomposition extends DecompositionSupport { +public class LUDecomposition implements Destroyable { /** Default bound to determine effective singularity in LU decomposition. 
*/ private static final double DEFAULT_TOO_SMALL = 1e-11; http://git-wip-us.apache.org/repos/asf/ignite/blob/934f6ac2/modules/ml/src/main/java/org/apache/ignite/ml/math/decompositions/QRDecomposition.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/decompositions/QRDecomposition.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/decompositions/QRDecomposition.java index 39215e8..5ffa574 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/math/decompositions/QRDecomposition.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/decompositions/QRDecomposition.java @@ -17,16 +17,21 @@ package org.apache.ignite.ml.math.decompositions; +import org.apache.ignite.ml.math.Destroyable; import org.apache.ignite.ml.math.Matrix; import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.exceptions.SingularMatrixException; import org.apache.ignite.ml.math.functions.Functions; +import static org.apache.ignite.ml.math.util.MatrixUtil.copy; +import static org.apache.ignite.ml.math.util.MatrixUtil.like; + /** * For an {@code m x n} matrix {@code A} with {@code m >= n}, the QR decomposition * is an {@code m x n} orthogonal matrix {@code Q} and an {@code n x n} upper * triangular matrix {@code R} so that {@code A = Q*R}. */ -public class QRDecomposition extends DecompositionSupport { +public class QRDecomposition implements Destroyable { /** */ private final Matrix q; /** */ @@ -41,6 +46,8 @@ public class QRDecomposition extends DecompositionSupport { private final int rows; /** */ private final int cols; + /** */ + private double threshold; /** * @param v Value to be checked for being an ordinary double. @@ -52,10 +59,21 @@ public class QRDecomposition extends DecompositionSupport { /** * Constructs a new QR decomposition object computed by Householder reflections. + * Threshold for singularity check used in this case is 0. 
* * @param mtx A rectangular matrix. */ public QRDecomposition(Matrix mtx) { + this(mtx, 0.0); + } + + /** + * Constructs a new QR decomposition object computed by Householder reflections. + * + * @param mtx A rectangular matrix. + * @param threshold Value used for detecting singularity of {@code R} matrix in decomposition. + */ + public QRDecomposition(Matrix mtx, double threshold) { assert mtx != null; rows = mtx.rowSize(); @@ -71,6 +89,7 @@ public class QRDecomposition extends DecompositionSupport { boolean fullRank = true; r = like(mtx, min, cols); + this.threshold = threshold; for (int i = 0; i < min; i++) { Vector qi = qTmp.viewColumn(i); @@ -155,18 +174,20 @@ public class QRDecomposition extends DecompositionSupport { throw new IllegalArgumentException("Matrix row dimensions must agree."); int cols = mtx.columnSize(); - + Matrix r = getR(); + checkSingular(r, threshold, true); Matrix x = like(mType, this.cols, cols); Matrix qt = getQ().transpose(); Matrix y = qt.times(mtx); - Matrix r = getR(); - - for (int k = Math.min(this.cols, rows) - 1; k > 0; k--) { + for (int k = Math.min(this.cols, rows) - 1; k >= 0; k--) { // X[k,] = Y[k,] / R[k,k], note that X[k,] starts with 0 so += is same as = x.viewRow(k).map(y.viewRow(k), Functions.plusMult(1 / r.get(k, k))); + if (k == 0) + continue; + // Y[0:(k-1),] -= R[0:(k-1),k] * X[k,] Vector rCol = r.viewColumn(k).viewPart(0, k); @@ -178,9 +199,48 @@ public class QRDecomposition extends DecompositionSupport { } /** + * Least squares solution of {@code A*X = B}; returns {@code X}. + * + * @param vec A vector with as many rows as {@code A}. + * @return {@code X} that minimizes the two norm of {@code Q*R*X - B}. + * @throws IllegalArgumentException if {@code B.rows() != A.rows()}. + */ + public Vector solve(Vector vec) { + Matrix res = solve(vec.likeMatrix(vec.size(), 1).assignColumn(0, vec)); + return vec.like(res.rowSize()).assign(res.viewColumn(0)); + } + + /** * Returns a rough string rendition of a QR. 
*/ @Override public String toString() { return String.format("QR(%d x %d, fullRank=%s)", rows, cols, hasFullRank()); } + + /** + * Check singularity. + * + * @param r R matrix. + * @param min Singularity threshold. + * @param raise Whether to raise a {@link SingularMatrixException} if any element of the diagonal fails the check. + * @return {@code true} if any element of the diagonal is smaller or equal to {@code min}. + * @throws SingularMatrixException if the matrix is singular and {@code raise} is {@code true}. + */ + private static boolean checkSingular(Matrix r, double min, boolean raise) { + // TODO: Not a very fast approach for distributed matrices. would be nice if we could independently check + // parts on different nodes for singularity and do fold with 'or'. + + final int len = r.columnSize(); + for (int i = 0; i < len; i++) { + final double d = r.getX(i, i); + if (Math.abs(d) <= min) + if (raise) + throw new SingularMatrixException("Number is too small (%f, while " + + "threshold is %f). 
Index of diagonal element is (%d, %d)", d, min, i, i); + else + return true; + + } + return false; + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/934f6ac2/modules/ml/src/main/java/org/apache/ignite/ml/math/decompositions/SingularValueDecomposition.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/decompositions/SingularValueDecomposition.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/decompositions/SingularValueDecomposition.java index 1b04e4f..68aeb6d 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/math/decompositions/SingularValueDecomposition.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/decompositions/SingularValueDecomposition.java @@ -18,8 +18,11 @@ package org.apache.ignite.ml.math.decompositions; import org.apache.ignite.ml.math.Algebra; +import org.apache.ignite.ml.math.Destroyable; import org.apache.ignite.ml.math.Matrix; +import static org.apache.ignite.ml.math.util.MatrixUtil.like; + /** * Compute a singular value decomposition (SVD) of {@code (l x k)} matrix {@code m}. * <p>This decomposition can be thought @@ -33,7 +36,7 @@ import org.apache.ignite.ml.math.Matrix; * <p>See also: <a href="https://en.wikipedia.org/wiki/Singular_value_decomposition">Wikipedia article on SVD</a>.</p> * <p>Note: complex case is currently not supported.</p> */ -public class SingularValueDecomposition extends DecompositionSupport { +public class SingularValueDecomposition implements Destroyable { // U and V. 
/** */ private final double[][] u; http://git-wip-us.apache.org/repos/asf/ignite/blob/934f6ac2/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/CardinalityException.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/CardinalityException.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/CardinalityException.java index f03e5d8..e8a073d 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/CardinalityException.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/CardinalityException.java @@ -17,12 +17,10 @@ package org.apache.ignite.ml.math.exceptions; -import org.apache.ignite.IgniteException; - /** * Indicates a cardinality mismatch in matrix or vector operations. */ -public class CardinalityException extends IgniteException { +public class CardinalityException extends MathIllegalArgumentException { /** */ private static final long serialVersionUID = 0L; @@ -33,6 +31,6 @@ public class CardinalityException extends IgniteException { * @param act Actual cardinality. 
*/ public CardinalityException(int exp, int act) { - super("Cardinality violation [expected=" + exp + ", actual=" + act + "]"); + super("Cardinality violation [expected=%d, actual=%d]", exp, act); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/934f6ac2/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/InsufficientDataException.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/InsufficientDataException.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/InsufficientDataException.java new file mode 100644 index 0000000..a57997d --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/InsufficientDataException.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.math.exceptions; + +/** + * This class is based on the corresponding class from Apache Common Math lib. + * Exception to be thrown when there is insufficient data to perform a computation. + */ +public class InsufficientDataException extends MathIllegalArgumentException { + /** Serializable version Id. 
*/ + private static final long serialVersionUID = -2629324471511903359L; + + /** */ + private static final String INSUFFICIENT_DATA = "Insufficient data."; + + /** + * Construct the exception. + */ + public InsufficientDataException() { + this(INSUFFICIENT_DATA); + } + + /** + * Construct the exception. + */ + public InsufficientDataException(String msg, Object... args) { + super(msg, args); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/934f6ac2/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/MathArithmeticException.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/MathArithmeticException.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/MathArithmeticException.java new file mode 100644 index 0000000..f48f3c5 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/MathArithmeticException.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.ignite.ml.math.exceptions; + +/** + * This class is based on the corresponding class from Apache Common Math lib. + * Base class for arithmetic exceptions. 
+ * It is used for all the exceptions that have the semantics of the standard + * {@link ArithmeticException}, but must also provide a localized + * message. + */ +public class MathArithmeticException extends MathRuntimeException { + /** Serializable version Id. */ + private static final long serialVersionUID = -6024911025449780478L; + + /** + * Default constructor. + */ + public MathArithmeticException() { + this("arithmetic exception"); + } + + /** + * Constructor with a specific message. + * + * @param format Message pattern providing the specific context of the error. + * @param args Arguments. + */ + public MathArithmeticException(String format, Object... args) { + super(format, args); + } + +} http://git-wip-us.apache.org/repos/asf/ignite/blob/934f6ac2/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/MathIllegalArgumentException.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/MathIllegalArgumentException.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/MathIllegalArgumentException.java new file mode 100644 index 0000000..eac685d --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/MathIllegalArgumentException.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.ignite.ml.math.exceptions; + +/** + * Base class for all preconditions violation exceptions. + * In most cases, this class should not be instantiated directly: it should + * serve as a base class to create all the exceptions that have the semantics + * of the standard {@link IllegalArgumentException}. + */ +public class MathIllegalArgumentException extends MathRuntimeException { + /** Serializable version Id. */ + private static final long serialVersionUID = -6024911025449780478L; + + /** + * @param format Message format string explaining the cause of the error. + * @param args Arguments. + */ + public MathIllegalArgumentException(String format, Object... args) { + super(String.format(format, args)); + } + +} http://git-wip-us.apache.org/repos/asf/ignite/blob/934f6ac2/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/MathRuntimeException.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/MathRuntimeException.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/MathRuntimeException.java new file mode 100644 index 0000000..865428e --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/MathRuntimeException.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.ignite.ml.math.exceptions; + +import org.apache.ignite.IgniteException; + +/** + * This class is based on the corresponding class from Apache Common Math lib. + * In most cases, this class should not be instantiated directly: it should + * serve as a base class for implementing exception classes that describe a + * specific "problem". + */ +public class MathRuntimeException extends IgniteException { + /** Serializable version Id. */ + private static final long serialVersionUID = 20120926L; + + /** + * @param format Message pattern explaining the cause of the error. + * @param args Arguments. + */ + public MathRuntimeException(String format, Object... args) { + this(null, format, args); + } + + /** + * @param cause Root cause. + * @param format Message pattern explaining the cause of the error. + * @param args Arguments. + */ + public MathRuntimeException(Throwable cause, String format, Object... 
args) { + super(String.format(format, args), cause); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/934f6ac2/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/NoDataException.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/NoDataException.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/NoDataException.java new file mode 100644 index 0000000..46d64aa --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/NoDataException.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.ignite.ml.math.exceptions; + +/** + * This class is based on the corresponding class from Apache Common Math lib. + * Exception to be thrown when the required data is missing. + */ +public class NoDataException extends MathIllegalArgumentException { + /** Serializable version Id. */ + private static final long serialVersionUID = -3629324471511904459L; + + /** */ + private static final String NO_DATA = "No data."; + + /** + * Construct the exception. 
+ */ + public NoDataException() { + this(NO_DATA); + } + + /** + * Construct the exception with a specific message. + * + * @param msg Message. + */ + public NoDataException(String msg) { + super(msg); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/934f6ac2/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/NonPositiveDefiniteMatrixException.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/NonPositiveDefiniteMatrixException.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/NonPositiveDefiniteMatrixException.java index b0cf294..2e588dc 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/NonPositiveDefiniteMatrixException.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/NonPositiveDefiniteMatrixException.java @@ -17,12 +17,10 @@ package org.apache.ignite.ml.math.exceptions; -import org.apache.ignite.IgniteException; - /** * This exception is used to indicate error condition of matrix elements failing the positivity check. */ -public class NonPositiveDefiniteMatrixException extends IgniteException { +public class NonPositiveDefiniteMatrixException extends MathIllegalArgumentException { /** * Construct an exception. * @@ -31,7 +29,7 @@ public class NonPositiveDefiniteMatrixException extends IgniteException { * @param threshold Absolute positivity threshold. 
*/ public NonPositiveDefiniteMatrixException(double wrong, int idx, double threshold) { - super("Matrix must be positive, wrong element located on diagonal with index " - + idx + " and has value " + wrong + " with this threshold " + threshold); + super("Matrix must be positive, wrong element located on diagonal with index %d and has value %f with this threshold %f", + idx, wrong, threshold); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/934f6ac2/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/NonSquareMatrixException.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/NonSquareMatrixException.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/NonSquareMatrixException.java new file mode 100644 index 0000000..5a4c207 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/NonSquareMatrixException.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.math.exceptions; + +/** + * Indicates that given matrix is not a square matrix. 
+ */ +public class NonSquareMatrixException extends CardinalityException { + /** + * Creates a new square size violation exception. + * + * @param exp Expected cardinality. + * @param act Actual cardinality. + */ + public NonSquareMatrixException(int exp, int act) { + super(exp, act); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/934f6ac2/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/NullArgumentException.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/NullArgumentException.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/NullArgumentException.java new file mode 100644 index 0000000..58a6fa3 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/NullArgumentException.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.ignite.ml.math.exceptions; + +/** + * This class is based on the corresponding class from the Apache Commons Math library. + * All condition checks that fail due to a {@code null} argument must throw + * this exception. 
+ * This class is meant to signal a precondition violation ("null is an illegal + * argument"). + */ +public class NullArgumentException extends NullPointerException { +} http://git-wip-us.apache.org/repos/asf/ignite/blob/934f6ac2/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/SingularMatrixException.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/SingularMatrixException.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/SingularMatrixException.java index 789b686..c7acc80 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/SingularMatrixException.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/SingularMatrixException.java @@ -17,14 +17,17 @@ package org.apache.ignite.ml.math.exceptions; -import org.apache.ignite.IgniteException; - /** * Exception to be thrown when a non-singular matrix is expected. */ -public class SingularMatrixException extends IgniteException { +public class SingularMatrixException extends MathIllegalArgumentException { /** */ public SingularMatrixException() { super("Regular (or non-singular) matrix expected."); } + + /** */ + public SingularMatrixException(String format, Object... 
args) { + super(format, args); + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/934f6ac2/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/Functions.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/Functions.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/Functions.java index 2f97067..e86a5eb 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/Functions.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/Functions.java @@ -133,6 +133,11 @@ public final class Functions { return (a, b) -> a - b * constant; } + /** Function that returns passed constant. */ + public static IgniteDoubleFunction<Double> constant(Double c) { + return a -> c; + } + /** * Function that returns {@code Math.pow(a, b)}. * http://git-wip-us.apache.org/repos/asf/ignite/blob/934f6ac2/modules/ml/src/main/java/org/apache/ignite/ml/math/util/MatrixUtil.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/util/MatrixUtil.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/util/MatrixUtil.java new file mode 100644 index 0000000..9277ae4 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/util/MatrixUtil.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.math.util; + +import org.apache.ignite.ml.math.Matrix; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.impls.matrix.CacheMatrix; +import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix; +import org.apache.ignite.ml.math.impls.matrix.MatrixView; +import org.apache.ignite.ml.math.impls.matrix.PivotedMatrixView; +import org.apache.ignite.ml.math.impls.matrix.RandomMatrix; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; + +/** + * Utility class for various matrix operations. + */ +public class MatrixUtil { + /** + * Create the like matrix with read-only matrices support. + * + * @param matrix Matrix for like. + * @return Like matrix. + */ + public static Matrix like(Matrix matrix) { + if (isCopyLikeSupport(matrix)) + return new DenseLocalOnHeapMatrix(matrix.rowSize(), matrix.columnSize()); + else + return matrix.like(matrix.rowSize(), matrix.columnSize()); + } + + /** + * Create the identity matrix like a given matrix. + * + * @param matrix Matrix for like. + * @return Identity matrix. + */ + public static Matrix identityLike(Matrix matrix, int n) { + Matrix res = like(matrix, n, n); + // TODO: Maybe we should introduce API for walking(and changing) matrix in + // a fastest possible visiting order. + for (int i = 0; i < n; i++) + res.setX(i, i, 1.0); + return res; + } + + /** + * Create the like matrix with specified size with read-only matrices support. + * + * @param matrix Matrix for like. + * @return Like matrix. 
+ */ + public static Matrix like(Matrix matrix, int rows, int cols) { + if (isCopyLikeSupport(matrix)) + return new DenseLocalOnHeapMatrix(rows, cols); + else + return matrix.like(rows, cols); + } + + /** + * Create the like vector with read-only matrices support. + * + * @param matrix Matrix for like. + * @param crd Cardinality of the vector. + * @return Like vector. + */ + public static Vector likeVector(Matrix matrix, int crd) { + if (isCopyLikeSupport(matrix)) + return new DenseLocalOnHeapVector(crd); + else + return matrix.likeVector(crd); + } + + /** + * Create the like vector with read-only matrices support. + * + * @param matrix Matrix for like. + * @return Like vector. + */ + public static Vector likeVector(Matrix matrix) { + return likeVector(matrix, matrix.rowSize()); + } + + /** + * Create the copy of matrix with read-only matrices support. + * + * @param matrix Matrix for copy. + * @return Copy. + */ + public static Matrix copy(Matrix matrix) { + if (isCopyLikeSupport(matrix)) { + DenseLocalOnHeapMatrix cp = new DenseLocalOnHeapMatrix(matrix.rowSize(), matrix.columnSize()); + + cp.assign(matrix); + + return cp; + } + else + return matrix.copy(); + } + + /** */ + private static boolean isCopyLikeSupport(Matrix matrix) { + return matrix instanceof RandomMatrix || matrix instanceof MatrixView || matrix instanceof CacheMatrix || + matrix instanceof PivotedMatrixView; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/934f6ac2/modules/ml/src/main/java/org/apache/ignite/ml/regressions/AbstractMultipleLinearRegression.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/AbstractMultipleLinearRegression.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/AbstractMultipleLinearRegression.java new file mode 100644 index 0000000..d558dc0 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/AbstractMultipleLinearRegression.java 
@@ -0,0 +1,358 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.regressions;
+
+import org.apache.ignite.ml.math.Matrix;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.exceptions.CardinalityException;
+import org.apache.ignite.ml.math.exceptions.InsufficientDataException;
+import org.apache.ignite.ml.math.exceptions.MathIllegalArgumentException;
+import org.apache.ignite.ml.math.exceptions.NoDataException;
+import org.apache.ignite.ml.math.exceptions.NonSquareMatrixException;
+import org.apache.ignite.ml.math.exceptions.NullArgumentException;
+import org.apache.ignite.ml.math.functions.Functions;
+import org.apache.ignite.ml.math.util.MatrixUtil;
+
+/**
+ * This class is based on the corresponding class from the Apache Commons Math library.
+ * Abstract base class for implementations of MultipleLinearRegression.
+ */
+public abstract class AbstractMultipleLinearRegression implements MultipleLinearRegression {
+    /** X sample data (the design matrix). */
+    private Matrix xMatrix;
+
+    /** Y sample data (the observed values). */
+    private Vector yVector;
+
+    /** Whether or not the regression model includes an intercept. True means no intercept.
*/
+    private boolean noIntercept = false;
+
+    /**
+     * @return the X sample data; the stored matrix itself is returned, not a copy.
+     */
+    protected Matrix getX() {
+        return xMatrix;
+    }
+
+    /**
+     * @return the Y sample data; the stored vector itself is returned, not a copy.
+     */
+    protected Vector getY() {
+        return yVector;
+    }
+
+    /**
+     * @return true if the model has no intercept term; false otherwise
+     */
+    public boolean isNoIntercept() {
+        return noIntercept;
+    }
+
+    /**
+     * Only records the flag; sample data that has already been loaded is not rebuilt.
+     *
+     * @param noIntercept true means the model is to be estimated without an intercept term
+     */
+    public void setNoIntercept(boolean noIntercept) {
+        this.noIntercept = noIntercept;
+    }
+
+    /**
+     * <p>Loads model x and y sample data from a flat input array, overriding any previous sample.
+     * </p>
+     * <p>Assumes that rows are concatenated with y values first in each row. For example, an input
+     * <code>data</code> array containing the sequence of values (1, 2, 3, 4, 5, 6, 7, 8, 9) with
+     * <code>nobs = 3</code> and <code>nvars = 2</code> creates a regression dataset with two
+     * independent variables, as below:
+     * <pre>
+     *   y   x[0]  x[1]
+     *   --------------
+     *   1     2     3
+     *   4     5     6
+     *   7     8     9
+     * </pre>
+     * </p>
+     * <p>Note that there is no need to add an initial unitary column (column of 1's) when
+     * specifying a model including an intercept term. If {@link #isNoIntercept()} is <code>true</code>,
+     * the X matrix will be created without an initial column of "1"s; otherwise this column will
+     * be added.
+     * </p>
+     * <p>Throws an exception (see the {@code @throws} tags below) if any of the following preconditions fail:
+     * <ul><li><code>data</code> cannot be null</li>
+     * <li><code>data.length = nobs * (nvars + 1)</code></li>
+     * <li><code>nobs > nvars</code></li></ul>
+     * </p>
+     * <p>The X matrix and Y vector are created through {@code MatrixUtil.like} and {@code MatrixUtil.likeVector}
+     * from the <code>like</code> argument and then filled from <code>data</code>.
+     * </p>
+     *
+     * @param data input data array
+     * @param nobs number of observations (rows)
+     * @param nvars number of independent variables (columns, not counting y)
+     * @param like matrix (may be empty) indicating how data should be stored
+     * @throws NullArgumentException if the data array is null
+     * @throws CardinalityException if the length of the data array is not equal to <code>nobs * (nvars + 1)</code>
+     * @throws InsufficientDataException if <code>nobs</code> is less than <code>nvars + 1</code>
+     */
+    public void newSampleData(double[] data, int nobs, int nvars, Matrix like) {
+        if (data == null)
+            throw new NullArgumentException();
+        if (data.length != nobs * (nvars + 1))
+            throw new CardinalityException(nobs * (nvars + 1), data.length);
+        if (nobs <= nvars)
+            throw new InsufficientDataException(RegressionsErrorMessages.INSUFFICIENT_OBSERVED_POINTS_IN_SAMPLE);
+        double[] y = new double[nobs];
+        final int cols = noIntercept ? nvars : nvars + 1;
+        double[][] x = new double[nobs][cols];
+        int pointer = 0;
+        for (int i = 0; i < nobs; i++) {
+            y[i] = data[pointer++];
+            if (!noIntercept)
+                x[i][0] = 1.0d;
+            for (int j = noIntercept ? 0 : 1; j < cols; j++)
+                x[i][j] = data[pointer++];
+        }
+        xMatrix = MatrixUtil.like(like, nobs, cols).assign(x);
+        yVector = MatrixUtil.likeVector(like, y.length).assign(y);
+    }
+
+    /**
+     * Loads new y sample data, overriding any previous data.
+     *
+     * @param y the vector representing the y sample
+     * @throws NullArgumentException if y is null
+     * @throws NoDataException if y is empty
+     */
+    protected void newYSampleData(Vector y) {
+        if (y == null)
+            throw new NullArgumentException();
+        if (y.size() == 0)
+            throw new NoDataException();
+        // TODO: Should we copy here?
+        yVector = y;
+    }
+
+    /**
+     * <p>Loads new x sample data, overriding any previous data.
+     * </p>
+     * The input <code>x</code> matrix should have one row for each sample
+     * observation, with columns corresponding to independent variables.
+     * For example, if <code>x</code> holds the values <pre>
+     * <code> {{1, 2}, {3, 4}, {5, 6}} </code></pre>
+     * then <code>newXSampleData(x) </code> results in a model with two independent
+     * variables and 3 observations:
+     * <pre>
+     *   x[0]  x[1]
+     *   ----------
+     *     1    2
+     *     3    4
+     *     5    6
+     * </pre>
+     * <p>Note that there is no need to add an initial unitary column (column of 1's) when
+     * specifying a model including an intercept term: unless the no-intercept flag is set, a new matrix
+     * with one extra leading column of ones is built here. When the flag is set, <code>x</code> itself is
+     * stored without copying.
+     * </p>
+     *
+     * @param x the matrix representing the x sample
+     * @throws NullArgumentException if x is null
+     * @throws NoDataException if x has no rows
+     */
+    protected void newXSampleData(Matrix x) {
+        if (x == null)
+            throw new NullArgumentException();
+        if (x.rowSize() == 0)
+            throw new NoDataException();
+        if (noIntercept)
+            // TODO: Should we copy here?
+            xMatrix = x;
+        else { // Augment design matrix with initial unitary column
+            xMatrix = MatrixUtil.like(x, x.rowSize(), x.columnSize() + 1);
+            xMatrix.viewColumn(0).map(Functions.constant(1.0));
+            xMatrix.viewPart(0, x.rowSize(), 1, x.columnSize()).assign(x);
+        }
+    }
+
+    /**
+     * Validates sample data. Checks that
+     * <ul><li>Neither x nor y is null or empty;</li>
+     * <li>The length (i.e. number of rows) of x equals the length of y</li>
+     * <li>x has at least one more row than it has columns (i.e.
there is + * sufficient data to estimate regression coefficients for each of the + * columns in x plus an intercept.</li> + * </ul> + * + * @param x the n x k matrix representing the x data + * @param y the n-sized vector representing the y data + * @throws NullArgumentException if {@code x} or {@code y} is null + * @throws CardinalityException if {@code x} and {@code y} do not have the same length + * @throws NoDataException if {@code x} or {@code y} are zero-length + * @throws MathIllegalArgumentException if the number of rows of {@code x} is not larger than the number of columns + * + 1 + */ + protected void validateSampleData(Matrix x, Vector y) throws MathIllegalArgumentException { + if ((x == null) || (y == null)) + throw new NullArgumentException(); + if (x.rowSize() != y.size()) + throw new CardinalityException(y.size(), x.rowSize()); + if (x.rowSize() == 0) { // Must be no y data either + throw new NoDataException(); + } + if (x.columnSize() + 1 > x.rowSize()) { + throw new MathIllegalArgumentException( + RegressionsErrorMessages.NOT_ENOUGH_DATA_FOR_NUMBER_OF_PREDICTORS, + x.rowSize(), x.columnSize()); + } + } + + /** + * Validates that the x data and covariance matrix have the same + * number of rows and that the covariance matrix is square. 
+ * + * @param x the [n,k] array representing the x sample + * @param covariance the [n,n] array representing the covariance matrix + * @throws CardinalityException if the number of rows in x is not equal to the number of rows in covariance + * @throws NonSquareMatrixException if the covariance matrix is not square + */ + protected void validateCovarianceData(double[][] x, double[][] covariance) { + if (x.length != covariance.length) + throw new CardinalityException(x.length, covariance.length); + if (covariance.length > 0 && covariance.length != covariance[0].length) + throw new NonSquareMatrixException(covariance.length, covariance[0].length); + } + + /** + * {@inheritDoc} + */ + @Override public double[] estimateRegressionParameters() { + Vector b = calculateBeta(); + return b.getStorage().data(); + } + + /** + * {@inheritDoc} + */ + @Override public double[] estimateResiduals() { + Vector b = calculateBeta(); + Vector e = yVector.minus(xMatrix.times(b)); + return e.getStorage().data(); + } + + /** + * {@inheritDoc} + */ + @Override public Matrix estimateRegressionParametersVariance() { + return calculateBetaVariance(); + } + + /** + * {@inheritDoc} + */ + @Override public double[] estimateRegressionParametersStandardErrors() { + Matrix betaVariance = estimateRegressionParametersVariance(); + double sigma = calculateErrorVariance(); + int len = betaVariance.rowSize(); + double[] res = new double[len]; + for (int i = 0; i < len; i++) + res[i] = Math.sqrt(sigma * betaVariance.getX(i, i)); + return res; + } + + /** + * {@inheritDoc} + */ + @Override public double estimateRegressandVariance() { + return calculateYVariance(); + } + + /** + * Estimates the variance of the error. + * + * @return estimate of the error variance + */ + public double estimateErrorVariance() { + return calculateErrorVariance(); + + } + + /** + * Estimates the standard error of the regression. 
+ * + * @return regression standard error + */ + public double estimateRegressionStandardError() { + return Math.sqrt(estimateErrorVariance()); + } + + /** + * Calculates the beta of multiple linear regression in matrix notation. + * + * @return beta + */ + protected abstract Vector calculateBeta(); + + /** + * Calculates the beta variance of multiple linear regression in matrix + * notation. + * + * @return beta variance + */ + protected abstract Matrix calculateBetaVariance(); + + /** + * Calculates the variance of the y values. + * + * @return Y variance + */ + protected double calculateYVariance() { + // Compute initial estimate using definitional formula + int vSize = yVector.size(); + double xbar = yVector.sum() / vSize; + // Compute correction factor in second pass + final double corr = yVector.foldMap((val, acc) -> acc + val - xbar, Functions.IDENTITY, 0.0); + final double mean = xbar - corr; + return yVector.foldMap(Functions.PLUS, val -> (val - mean) * (val - mean), 0.0) / (vSize - 1); + } + + /** + * <p>Calculates the variance of the error term.</p> + * Uses the formula <pre> + * var(u) = u · u / (n - k) + * </pre> + * where n and k are the row and column dimensions of the design + * matrix X. + * + * @return error variance estimate + */ + protected double calculateErrorVariance() { + Vector residuals = calculateResiduals(); + return residuals.dot(residuals) / + (xMatrix.rowSize() - xMatrix.columnSize()); + } + + /** + * Calculates the residuals of multiple linear regression in matrix + * notation. 
+     *
+     * <pre>
+     * u = y - X * b
+     * </pre>
+     *
+     * @return The residuals as an n-sized vector
+     */
+    protected Vector calculateResiduals() {
+        Vector b = calculateBeta();
+        return yVector.minus(xMatrix.times(b));
+    }
+
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/934f6ac2/modules/ml/src/main/java/org/apache/ignite/ml/regressions/MultipleLinearRegression.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/MultipleLinearRegression.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/MultipleLinearRegression.java
new file mode 100644
index 0000000..2fc4dde
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/MultipleLinearRegression.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.regressions;
+
+import org.apache.ignite.ml.math.Matrix;
+
+/**
+ * This class is based on the corresponding class from the Apache Commons Math library. The multiple linear
+ * regression can be represented in matrix-notation.
+ * <pre>
+ * y=X*b+u
+ * </pre>
+ * where y is an <code>n-vector</code> <b>regressand</b>, X is a <code>[n,k]</code> matrix whose <code>k</code> columns
+ * are called <b>regressors</b>, b is a <code>k-vector</code> of <b>regression parameters</b> and <code>u</code> is an
+ * <code>n-vector</code> of <b>error terms</b> or <b>residuals</b>.
+ * <p>
+ * The notation is quite standard in literature, cf. e.g. <a href="http://www.econ.queensu.ca/ETM">Davidson and MacKinnon,
+ * Econometrics Theory and Methods, 2004</a>. </p>
+ */
+public interface MultipleLinearRegression {
+    /**
+     * Estimates the regression parameters b.
+     *
+     * @return The k-element array representing b
+     */
+    public double[] estimateRegressionParameters();
+
+    /**
+     * Estimates the variance of the regression parameters, i.e. Var(b).
+     *
+     * @return The k x k matrix representing the variance of b
+     */
+    public Matrix estimateRegressionParametersVariance();
+
+    /**
+     * Estimates the residuals, i.e. u = y - X*b.
+     *
+     * @return The n-element array representing the residuals
+     */
+    public double[] estimateResiduals();
+
+    /**
+     * Returns the variance of the regressand, i.e. Var(y).
+     *
+     * @return The double representing the variance of y
+     */
+    public double estimateRegressandVariance();
+
+    /**
+     * Returns the standard errors of the regression parameters.
+     *
+     * @return standard errors of estimated regression parameters
+     */
+    public double[] estimateRegressionParametersStandardErrors();
+
+}