Repository: spark
Updated Branches:
  refs/heads/branch-2.3 d42610440 -> 9d63e540e


[SPARK-24216][SQL] Spark TypedAggregateExpression uses getSimpleName that is 
not safe in scala

When a user creates an aggregator object in Scala and passes it to Spark 
Dataset's agg() method, Spark initializes TypedAggregateExpression with the 
nodeName field set to aggregator.getClass.getSimpleName. However, getSimpleName 
is not safe in a Scala environment, depending on how the user creates the 
aggregator object. For example, if the aggregator class's fully qualified name 
is "com.my.company.MyUtils$myAgg$2$", getSimpleName will throw 
java.lang.InternalError "Malformed class name". This has been reported in 
scalatest https://github.com/scalatest/scalatest/pull/1044 and discussed in 
many Scala upstream JIRAs such as SI-8110 and SI-5425.

To fix this issue, we follow the solution in 
https://github.com/scalatest/scalatest/pull/1044 and add a safer version of 
getSimpleName as a utility method; TypedAggregateExpression now invokes this 
utility method rather than getClass.getSimpleName.

Added a unit test.

Author: Fangshi Li <f...@linkedin.com>

Closes #21276 from fangshil/SPARK-24216.

(cherry picked from commit cc88d7fad16e8b5cbf7b6b9bfe412908782b4a45)
Signed-off-by: Wenchen Fan <wenc...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9d63e540
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9d63e540
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9d63e540

Branch: refs/heads/branch-2.3
Commit: 9d63e540e00bc655faf6d8fe1d0035bc0b9a9192
Parents: d426104
Author: Fangshi Li <f...@linkedin.com>
Authored: Tue Jun 12 12:10:08 2018 -0700
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Fri Jun 15 20:23:05 2018 -0700

----------------------------------------------------------------------
 .../org/apache/spark/util/AccumulatorV2.scala   |  6 +-
 .../scala/org/apache/spark/util/Utils.scala     | 59 +++++++++++++++++++-
 .../org/apache/spark/util/UtilsSuite.scala      | 16 ++++++
 .../apache/spark/ml/util/Instrumentation.scala  |  5 +-
 .../aggregate/TypedAggregateExpression.scala    |  5 +-
 5 files changed, 86 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9d63e540/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala 
b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
index 3b469a6..bf618b4 100644
--- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
+++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
@@ -200,10 +200,12 @@ abstract class AccumulatorV2[IN, OUT] extends 
Serializable {
   }
 
   override def toString: String = {
+    // getClass.getSimpleName can cause Malformed class name error,
+    // call safer `Utils.getSimpleName` instead
     if (metadata == null) {
-      "Un-registered Accumulator: " + getClass.getSimpleName
+      "Un-registered Accumulator: " + Utils.getSimpleName(getClass)
     } else {
-      getClass.getSimpleName + s"(id: $id, name: $name, value: $value)"
+      Utils.getSimpleName(getClass) + s"(id: $id, name: $name, value: $value)"
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/9d63e540/core/src/main/scala/org/apache/spark/util/Utils.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala 
b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 12d0934..d4b72e8 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -19,6 +19,7 @@ package org.apache.spark.util
 
 import java.io._
 import java.lang.{Byte => JByte}
+import java.lang.InternalError
 import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, 
ThreadInfo}
 import java.lang.reflect.InvocationTargetException
 import java.math.{MathContext, RoundingMode}
@@ -1875,7 +1876,7 @@ private[spark] object Utils extends Logging {
 
   /** Return the class name of the given object, removing all dollar signs */
   def getFormattedClassName(obj: AnyRef): String = {
-    obj.getClass.getSimpleName.replace("$", "")
+    getSimpleName(obj.getClass).replace("$", "")
   }
 
   /** Return an option that translates JNothing to None */
@@ -2817,6 +2818,62 @@ private[spark] object Utils extends Logging {
     HashCodes.fromBytes(secretBytes).toString()
   }
 
+  /**
+   * Safer than Class obj's getSimpleName which may throw Malformed class name 
error in scala.
+   * This method mimics scalatest's getSimpleNameOfAnObjectsClass.
+   */
+  def getSimpleName(cls: Class[_]): String = {
+    try {
+      return cls.getSimpleName
+    } catch {
+      case err: InternalError => return 
stripDollars(stripPackages(cls.getName))
+    }
+  }
+
+  /**
+   * Remove the packages from fully qualified class name
+   */
+  private def stripPackages(fullyQualifiedName: String): String = {
+    fullyQualifiedName.split("\\.").takeRight(1)(0)
+  }
+
+  /**
+   * Remove trailing dollar signs from qualified class name,
+   * and return the trailing part after the last dollar sign in the middle
+   */
+  private def stripDollars(s: String): String = {
+    val lastDollarIndex = s.lastIndexOf('$')
+    if (lastDollarIndex < s.length - 1) {
+      // The last char is not a dollar sign
+      if (lastDollarIndex == -1 || !s.contains("$iw")) {
+        // The name does not have a dollar sign or is not an interpreter
+        // generated class, so we should return the full string
+        s
+      } else {
+        // The class name is interpreter generated,
+        // return the part after the last dollar sign
+        // This is the same behavior as getClass.getSimpleName
+        s.substring(lastDollarIndex + 1)
+      }
+    }
+    else {
+      // The last char is a dollar sign
+      // Find last non-dollar char
+      val lastNonDollarChar = s.reverse.find(_ != '$')
+      lastNonDollarChar match {
+        case None => s
+        case Some(c) =>
+          val lastNonDollarIndex = s.lastIndexOf(c)
+          if (lastNonDollarIndex == -1) {
+            s
+          } else {
+            // Strip the trailing dollar signs
+            // Invoke stripDollars again to get the simple name
+            stripDollars(s.substring(0, lastNonDollarIndex + 1))
+          }
+      }
+    }
+  }
 }
 
 private[util] object CallerContext extends Logging {

http://git-wip-us.apache.org/repos/asf/spark/blob/9d63e540/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala 
b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index eaea6b0..cde250c 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -1167,6 +1167,22 @@ class UtilsSuite extends SparkFunSuite with 
ResetSystemProperties with Logging {
       Utils.checkAndGetK8sMasterUrl("k8s://foo://host:port")
     }
   }
+
+  object MalformedClassObject {
+    class MalformedClass
+  }
+
+  test("Safe getSimpleName") {
+    // getSimpleName on class of MalformedClass will result in error: 
Malformed class name
+    // Utils.getSimpleName works
+    val err = intercept[java.lang.InternalError] {
+      classOf[MalformedClassObject.MalformedClass].getSimpleName
+    }
+    assert(err.getMessage === "Malformed class name")
+
+    assert(Utils.getSimpleName(classOf[MalformedClassObject.MalformedClass]) 
===
+      "UtilsSuite$MalformedClassObject$MalformedClass")
+  }
 }
 
 private class SimpleExtension

http://git-wip-us.apache.org/repos/asf/spark/blob/9d63e540/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala 
b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
index 7c46f45..8920e61 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
@@ -28,6 +28,7 @@ import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.param.Param
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Dataset
+import org.apache.spark.util.Utils
 
 /**
  * A small wrapper that defines a training session for an estimator, and some 
methods to log
@@ -44,7 +45,9 @@ private[spark] class Instrumentation[E <: Estimator[_]] 
private (
 
   private val id = Instrumentation.counter.incrementAndGet()
   private val prefix = {
-    val className = estimator.getClass.getSimpleName
+    // estimator.getClass.getSimpleName can cause Malformed class name error,
+    // call safer `Utils.getSimpleName` instead
+    val className = Utils.getSimpleName(estimator.getClass)
     s"$className-${estimator.uid}-${dataset.hashCode()}-$id: "
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/9d63e540/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index aab8cc5..6d44890 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -29,6 +29,7 @@ import 
org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection
 import org.apache.spark.sql.catalyst.expressions.objects.Invoke
 import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
 
 object TypedAggregateExpression {
   def apply[BUF : Encoder, OUT : Encoder](
@@ -109,7 +110,9 @@ trait TypedAggregateExpression extends AggregateFunction {
     s"$nodeName($input)"
   }
 
-  override def nodeName: String = 
aggregator.getClass.getSimpleName.stripSuffix("$")
+  // aggregator.getClass.getSimpleName can cause Malformed class name error,
+  // call safer `Utils.getSimpleName` instead
+  override def nodeName: String = 
Utils.getSimpleName(aggregator.getClass).stripSuffix("$");
 }
 
 // TODO: merge these 2 implementations once we refactor the 
`AggregateFunction` interface.


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to