Repository: spark Updated Branches: refs/heads/master cdaa562c9 -> 0ebf7c1bf
[SPARK-17027][ML] Avoid integer overflow in PolynomialExpansion.getPolySize ## What changes were proposed in this pull request? Replaces custom choose function with o.a.commons.math3.CombinatoricsUtils.binomialCoefficient ## How was this patch tested? Spark unit tests Author: zero323 <zero...@users.noreply.github.com> Closes #14614 from zero323/SPARK-17027. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0ebf7c1b Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0ebf7c1b Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0ebf7c1b Branch: refs/heads/master Commit: 0ebf7c1bff736cf54ec47957d71394d5b75b47a7 Parents: cdaa562 Author: zero323 <zero...@users.noreply.github.com> Authored: Sun Aug 14 11:59:24 2016 +0100 Committer: Sean Owen <so...@cloudera.com> Committed: Sun Aug 14 11:59:24 2016 +0100 ---------------------------------------------------------------------- .../spark/ml/feature/PolynomialExpansion.scala | 10 ++++---- .../ml/feature/PolynomialExpansionSuite.scala | 24 ++++++++++++++++++++ 2 files changed, 30 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/0ebf7c1b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala index 72fb35b..6e872c1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala @@ -19,6 +19,8 @@ package org.apache.spark.ml.feature import scala.collection.mutable +import org.apache.commons.math3.util.CombinatoricsUtils + import org.apache.spark.annotation.Since import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.linalg._ @@ -84,12 +86,12 @@ class PolynomialExpansion @Since("1.4.0") (@Since("1.4.0") override val uid: Str @Since("1.6.0") object PolynomialExpansion extends DefaultParamsReadable[PolynomialExpansion] { - private def choose(n: Int, k: Int): Int = { - Range(n, n - k, -1).product / Range(k, 1, -1).product + private def getPolySize(numFeatures: Int, degree: Int): Int = { + val n = CombinatoricsUtils.binomialCoefficient(numFeatures + degree, degree) + require(n <= Integer.MAX_VALUE) + n.toInt } - private def getPolySize(numFeatures: Int, degree: Int): Int = choose(numFeatures + degree, degree) - private def expandDense( values: Array[Double], lastIdx: Int, http://git-wip-us.apache.org/repos/asf/spark/blob/0ebf7c1b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala index 8e1f9dd..9ecd321 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala @@ -116,5 +116,29 @@ class PolynomialExpansionSuite .setDegree(3) testDefaultReadWrite(t) } + + test("SPARK-17027. Integer overflow in PolynomialExpansion.getPolySize") { + val data: Array[(Vector, Int, Int)] = Array( + (Vectors.dense(1.0, 2.0, 3.0, 4.0, 5.0), 3002, 4367), + (Vectors.sparse(5, Seq((0, 1.0), (4, 5.0))), 3002, 4367), + (Vectors.dense(1.0, 2.0, 3.0, 4.0, 5.0, 6.0), 8007, 12375) + ) + + val df = spark.createDataFrame(data) + .toDF("features", "expectedPoly10size", "expectedPoly11size") + + val t = new PolynomialExpansion() + .setInputCol("features") + .setOutputCol("polyFeatures") + + for (i <- Seq(10, 11)) { + val transformed = t.setDegree(i) + .transform(df) + .select(s"expectedPoly${i}size", "polyFeatures") + .rdd.map { case Row(expected: Int, v: Vector) => expected == v.size } + + assert(transformed.collect.forall(identity)) + } + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org