This is an automated email from the ASF dual-hosted git repository.

joshrosen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new fc65e0f  [SPARK-27839][SQL] Change UTF8String.replace() to operate on 
UTF8 bytes
fc65e0f is described below

commit fc65e0fe2c8a114feba47d8f7b63628a676dd24c
Author: Josh Rosen <rosenvi...@gmail.com>
AuthorDate: Wed Jun 19 15:21:26 2019 -0700

    [SPARK-27839][SQL] Change UTF8String.replace() to operate on UTF8 bytes
    
    ## What changes were proposed in this pull request?
    
    This PR significantly improves the performance of `UTF8String.replace()` by 
performing direct replacement over UTF8 bytes instead of decoding those bytes 
into Java Strings.
    
    In cases where the search string is not found (i.e. no replacements are 
performed, a case which I expect to be common) this new implementation performs 
no object allocation or memory copying.
    
    My implementation is modeled after `commons-lang3`'s 
`StringUtils.replace()` method. As part of my implementation, I needed a 
StringBuilder / resizable buffer, so I moved `UTF8StringBuilder` from the 
`catalyst` package to `unsafe`.
    
    ## How was this patch tested?
    
    Copied tests from `StringExpressionSuite` to `UTF8StringSuite` and added a 
couple of new cases.
    
    To evaluate performance, I did some quick local benchmarking by running the 
following code in `spark-shell` (with Java 1.8.0_191):
    
    ```scala
    import org.apache.spark.unsafe.types.UTF8String
    
    def benchmark(text: String, search: String, replace: String) {
      val utf8Text = UTF8String.fromString(text)
      val utf8Search = UTF8String.fromString(search)
      val utf8Replace = UTF8String.fromString(replace)
    
      val start = System.currentTimeMillis
      var i = 0
      while (i < 1000 * 1000 * 100) {
        utf8Text.replace(utf8Search, utf8Replace)
        i += 1
      }
      val end = System.currentTimeMillis
    
      println(end - start)
    }
    
    benchmark("ABCDEFGH", "DEF", "ZZZZ")  // replacement occurs
    benchmark("ABCDEFGH", "Z", "")  // no replacement occurs
    ```
    
    On my laptop this took ~54 / ~40 seconds before this patch's 
changes and ~6.5 / ~3.8 seconds afterwards.
    
    Closes #24707 from JoshRosen/faster-string-replace.
    
    Authored-by: Josh Rosen <rosenvi...@gmail.com>
    Signed-off-by: Josh Rosen <rosenvi...@gmail.com>
---
 .../apache/spark/unsafe}/UTF8StringBuilder.java    | 27 +++++++++++++--
 .../org/apache/spark/unsafe/types/UTF8String.java  | 26 ++++++++++++---
 .../apache/spark/unsafe/types/UTF8StringSuite.java | 38 ++++++++++++++++++++++
 .../spark/sql/catalyst/expressions/Cast.scala      |  1 +
 .../expressions/collectionOperations.scala         |  1 +
 5 files changed, 86 insertions(+), 7 deletions(-)

diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java
 b/common/unsafe/src/main/java/org/apache/spark/unsafe/UTF8StringBuilder.java
similarity index 80%
rename from 
sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java
rename to 
common/unsafe/src/main/java/org/apache/spark/unsafe/UTF8StringBuilder.java
index f0f66ba..481ea89 100644
--- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/UTF8StringBuilder.java
@@ -15,9 +15,8 @@
  * limitations under the License.
  */
 
-package org.apache.spark.sql.catalyst.expressions.codegen;
+package org.apache.spark.unsafe;
 
-import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.types.UTF8String;
 
@@ -34,7 +33,18 @@ public class UTF8StringBuilder {
 
   public UTF8StringBuilder() {
     // Since initial buffer size is 16 in `StringBuilder`, we set the same 
size here
-    this.buffer = new byte[16];
+    this(16);
+  }
+
+  public UTF8StringBuilder(int initialSize) {
+    if (initialSize < 0) {
+      throw new IllegalArgumentException("Size must be non-negative");
+    }
+    if (initialSize > ARRAY_MAX) {
+      throw new IllegalArgumentException(
+        "Size " + initialSize + " exceeded maximum size of " + ARRAY_MAX);
+    }
+    this.buffer = new byte[initialSize];
   }
 
   // Grows the buffer by at least `neededSize`
@@ -72,6 +82,17 @@ public class UTF8StringBuilder {
     append(UTF8String.fromString(value));
   }
 
+  public void appendBytes(Object base, long offset, int length) {
+    grow(length);
+    Platform.copyMemory(
+      base,
+      offset,
+      buffer,
+      cursor,
+      length);
+    cursor += length;
+  }
+
   public UTF8String build() {
     return UTF8String.fromBytes(buffer, 0, totalSize());
   }
diff --git 
a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java 
b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 0550127..30b884c 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -32,6 +32,7 @@ import com.esotericsoftware.kryo.io.Output;
 import com.google.common.primitives.Ints;
 
 import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.UTF8StringBuilder;
 import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.hash.Murmur3_x86_32;
 
@@ -1002,12 +1003,29 @@ public final class UTF8String implements 
Comparable<UTF8String>, Externalizable,
   }
 
   public UTF8String replace(UTF8String search, UTF8String replace) {
-    if (EMPTY_UTF8.equals(search)) {
+    // This implementation is loosely based on commons-lang3's 
StringUtils.replace().
+    if (numBytes == 0 || search.numBytes == 0) {
       return this;
     }
-    String replaced = toString().replace(
-      search.toString(), replace.toString());
-    return fromString(replaced);
+    // Find the first occurrence of the search string.
+    int start = 0;
+    int end = this.find(search, start);
+    if (end == -1) {
+      // Search string was not found, so string is unchanged.
+      return this;
+    }
+    // At least one match was found. Estimate space needed for result.
+    // The 16x multiplier here is chosen to match commons-lang3's 
implementation.
+    int increase = Math.max(0, replace.numBytes - search.numBytes) * 16;
+    final UTF8StringBuilder buf = new UTF8StringBuilder(numBytes + increase);
+    while (end != -1) {
+      buf.appendBytes(this.base, this.offset + start, end - start);
+      buf.append(replace);
+      start = end + search.numBytes;
+      end = this.find(search, start);
+    }
+    buf.appendBytes(this.base, this.offset + start, numBytes - start);
+    return buf.build();
   }
 
   // TODO: Need to use `Code Point` here instead of Char in case the character 
longer than 2 bytes
diff --git 
a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
 
b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index cf9cc6b..bc75fa9 100644
--- 
a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ 
b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -404,6 +404,44 @@ public class UTF8StringSuite {
   }
 
   @Test
+  public void replace() {
+    assertEquals(
+      fromString("re123ace"),
+      fromString("replace").replace(fromString("pl"), fromString("123")));
+    assertEquals(
+      fromString("reace"),
+      fromString("replace").replace(fromString("pl"), fromString("")));
+    assertEquals(
+      fromString("replace"),
+      fromString("replace").replace(fromString(""), fromString("123")));
+    // tests for multiple replacements
+    assertEquals(
+      fromString("a12ca12c"),
+      fromString("abcabc").replace(fromString("b"), fromString("12")));
+    assertEquals(
+      fromString("adad"),
+      fromString("abcdabcd").replace(fromString("bc"), fromString("")));
+    // tests for single character search and replacement strings
+    assertEquals(
+      fromString("AbcAbc"),
+      fromString("abcabc").replace(fromString("a"), fromString("A")));
+    assertEquals(
+      fromString("abcabc"),
+      fromString("abcabc").replace(fromString("Z"), fromString("A")));
+    // Tests with non-ASCII characters
+    assertEquals(
+      fromString("花ab界"),
+      fromString("花花世界").replace(fromString("花世"), fromString("ab")));
+    assertEquals(
+      fromString("a水c"),
+      fromString("a火c").replace(fromString("火"), fromString("水")));
+    // Tests for a large number of replacements, triggering UTF8StringBuilder 
resize
+    assertEquals(
+      fromString("abcd").repeat(17),
+      fromString("a").repeat(17).replace(fromString("a"), fromString("abcd")));
+  }
+
+  @Test
   public void levenshteinDistance() {
     assertEquals(0, EMPTY_UTF8.levenshteinDistance(EMPTY_UTF8));
     assertEquals(1, EMPTY_UTF8.levenshteinDistance(fromString("a")));
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index f8c1102..9691288 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -29,6 +29,7 @@ import 
org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils._
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.UTF8StringBuilder
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 import org.apache.spark.unsafe.types.UTF8String.{IntWrapper, LongWrapper}
 
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 41d9b06..8477e63 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.UTF8StringBuilder
 import org.apache.spark.unsafe.array.ByteArrayMethods
 import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
 import org.apache.spark.unsafe.types.{ByteArray, UTF8String}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to