Repository: spark
Updated Branches:
  refs/heads/master dcae355c6 -> 6f0f1d9e0


[SPARK-12934][SQL] Count-min sketch serialization

This PR adds serialization support for `CountMinSketch`.

A version number is added to the serialized binary format to identify which format version was used.

Author: Cheng Lian <l...@databricks.com>

Closes #10893 from liancheng/cms-serialization.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6f0f1d9e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6f0f1d9e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6f0f1d9e

Branch: refs/heads/master
Commit: 6f0f1d9e04a8db47e2f6f8fcfe9dea9de0f633da
Parents: dcae355
Author: Cheng Lian <l...@databricks.com>
Authored: Mon Jan 25 15:05:05 2016 -0800
Committer: Reynold Xin <r...@databricks.com>
Committed: Mon Jan 25 15:05:05 2016 -0800

----------------------------------------------------------------------
 .../spark/util/sketch/CountMinSketch.java       |  32 ++++-
 .../spark/util/sketch/CountMinSketchImpl.java   | 129 +++++++++++++++++--
 .../util/sketch/IncompatibleMergeException.java |  24 ++++
 .../spark/util/sketch/CountMinSketchSuite.scala |  47 ++++++-
 4 files changed, 213 insertions(+), 19 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6f0f1d9e/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
----------------------------------------------------------------------
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
index 21b161b..6793864 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
@@ -17,6 +17,7 @@
 
 package org.apache.spark.util.sketch;
 
+import java.io.IOException;
 import java.io.InputStream;
 import java.io.OutputStream;
 
@@ -55,6 +56,25 @@ import java.io.OutputStream;
  */
 abstract public class CountMinSketch {
   /**
+   * Version number of the serialized binary format.
+   */
+  public enum Version {
+    V1(1);
+
+    private final int versionNumber;
+
+    Version(int versionNumber) {
+      this.versionNumber = versionNumber;
+    }
+
+    public int getVersionNumber() {
+      return versionNumber;
+    }
+  }
+
+  public abstract Version version();
+
+  /**
    * Returns the relative error (or {@code eps}) of this {@link 
CountMinSketch}.
    */
   public abstract double relativeError();
@@ -99,19 +119,23 @@ abstract public class CountMinSketch {
    *
    * Note that only Count-Min sketches with the same {@code depth}, {@code 
width}, and random seed
    * can be merged.
+   *
+   * @exception IncompatibleMergeException if the {@code other} {@link 
CountMinSketch} has
+   *            incompatible depth, width, relative-error, confidence, or 
random seed.
    */
-  public abstract CountMinSketch mergeInPlace(CountMinSketch other);
+  public abstract CountMinSketch mergeInPlace(CountMinSketch other)
+      throws IncompatibleMergeException;
 
   /**
    * Writes out this {@link CountMinSketch} to an output stream in binary 
format.
    */
-  public abstract void writeTo(OutputStream out);
+  public abstract void writeTo(OutputStream out) throws IOException;
 
   /**
    * Reads in a {@link CountMinSketch} from an input stream.
    */
-  public static CountMinSketch readFrom(InputStream in) {
-    throw new UnsupportedOperationException("Not implemented yet");
+  public static CountMinSketch readFrom(InputStream in) throws IOException {
+    return CountMinSketchImpl.readFrom(in);
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/6f0f1d9e/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
----------------------------------------------------------------------
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
index e9fdbe3..0209446 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
@@ -17,11 +17,30 @@
 
 package org.apache.spark.util.sketch;
 
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
 import java.io.OutputStream;
 import java.io.UnsupportedEncodingException;
 import java.util.Arrays;
 import java.util.Random;
 
+/*
+ * Binary format of a serialized CountMinSketchImpl, version 1 (all values 
written in big-endian
+ * order):
+ *
+ * - Version number, always 1 (32 bit)
+ * - Total count of added items (64 bit)
+ * - Depth (32 bit)
+ * - Width (32 bit)
+ * - Hash functions (depth * 64 bit)
+ * - Count table
+ *   - Row 0 (width * 64 bit)
+ *   - Row 1 (width * 64 bit)
+ *   - ...
+ *   - Row depth - 1 (width * 64 bit)
+ */
 class CountMinSketchImpl extends CountMinSketch {
   public static final long PRIME_MODULUS = (1L << 31) - 1;
 
@@ -33,7 +52,7 @@ class CountMinSketchImpl extends CountMinSketch {
   private double eps;
   private double confidence;
 
-  public CountMinSketchImpl(int depth, int width, int seed) {
+  CountMinSketchImpl(int depth, int width, int seed) {
     this.depth = depth;
     this.width = width;
     this.eps = 2.0 / width;
@@ -41,7 +60,7 @@ class CountMinSketchImpl extends CountMinSketch {
     initTablesWith(depth, width, seed);
   }
 
-  public CountMinSketchImpl(double eps, double confidence, int seed) {
+  CountMinSketchImpl(double eps, double confidence, int seed) {
     // 2/w = eps ; w = 2/eps
     // 1/2^depth <= 1-confidence ; depth >= -log2 (1-confidence)
     this.eps = eps;
@@ -51,6 +70,53 @@ class CountMinSketchImpl extends CountMinSketch {
     initTablesWith(depth, width, seed);
   }
 
+  CountMinSketchImpl(int depth, int width, long totalCount, long hashA[], long 
table[][]) {
+    this.depth = depth;
+    this.width = width;
+    this.eps = 2.0 / width;
+    this.confidence = 1 - 1 / Math.pow(2, depth);
+    this.hashA = hashA;
+    this.table = table;
+    this.totalCount = totalCount;
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (other == this) {
+      return true;
+    }
+
+    if (other == null || !(other instanceof CountMinSketchImpl)) {
+      return false;
+    }
+
+    CountMinSketchImpl that = (CountMinSketchImpl) other;
+
+    return
+      this.depth == that.depth &&
+      this.width == that.width &&
+      this.totalCount == that.totalCount &&
+      Arrays.equals(this.hashA, that.hashA) &&
+      Arrays.deepEquals(this.table, that.table);
+  }
+
+  @Override
+  public int hashCode() {
+    int hash = depth;
+
+    hash = hash * 31 + width;
+    hash = hash * 31 + (int) (totalCount ^ (totalCount >>> 32));
+    hash = hash * 31 + Arrays.hashCode(hashA);
+    hash = hash * 31 + Arrays.deepHashCode(table);
+
+    return hash;
+  }
+
+  @Override
+  public Version version() {
+    return Version.V1;
+  }
+
   private void initTablesWith(int depth, int width, int seed) {
     this.table = new long[depth][width];
     this.hashA = new long[depth];
@@ -221,27 +287,29 @@ class CountMinSketchImpl extends CountMinSketch {
   }
 
   @Override
-  public CountMinSketch mergeInPlace(CountMinSketch other) {
+  public CountMinSketch mergeInPlace(CountMinSketch other) throws 
IncompatibleMergeException {
     if (other == null) {
-      throw new CMSMergeException("Cannot merge null estimator");
+      throw new IncompatibleMergeException("Cannot merge null estimator");
     }
 
     if (!(other instanceof CountMinSketchImpl)) {
-      throw new CMSMergeException("Cannot merge estimator of class " + 
other.getClass().getName());
+      throw new IncompatibleMergeException(
+          "Cannot merge estimator of class " + other.getClass().getName()
+      );
     }
 
     CountMinSketchImpl that = (CountMinSketchImpl) other;
 
     if (this.depth != that.depth) {
-      throw new CMSMergeException("Cannot merge estimators of different 
depth");
+      throw new IncompatibleMergeException("Cannot merge estimators of 
different depth");
     }
 
     if (this.width != that.width) {
-      throw new CMSMergeException("Cannot merge estimators of different 
width");
+      throw new IncompatibleMergeException("Cannot merge estimators of 
different width");
     }
 
     if (!Arrays.equals(this.hashA, that.hashA)) {
-      throw new CMSMergeException("Cannot merge estimators of different seed");
+      throw new IncompatibleMergeException("Cannot merge estimators of 
different seed");
     }
 
     for (int i = 0; i < this.table.length; ++i) {
@@ -256,13 +324,48 @@ class CountMinSketchImpl extends CountMinSketch {
   }
 
   @Override
-  public void writeTo(OutputStream out) {
-    throw new UnsupportedOperationException("Not implemented yet");
+  public void writeTo(OutputStream out) throws IOException {
+    DataOutputStream dos = new DataOutputStream(out);
+
+    dos.writeInt(version().getVersionNumber());
+
+    dos.writeLong(this.totalCount);
+    dos.writeInt(this.depth);
+    dos.writeInt(this.width);
+
+    for (int i = 0; i < this.depth; ++i) {
+      dos.writeLong(this.hashA[i]);
+    }
+
+    for (int i = 0; i < this.depth; ++i) {
+      for (int j = 0; j < this.width; ++j) {
+        dos.writeLong(table[i][j]);
+      }
+    }
   }
 
-  protected static class CMSMergeException extends RuntimeException {
-    public CMSMergeException(String message) {
-      super(message);
+  public static CountMinSketchImpl readFrom(InputStream in) throws IOException 
{
+    DataInputStream dis = new DataInputStream(in);
+
+    // Ignores version number
+    dis.readInt();
+
+    long totalCount = dis.readLong();
+    int depth = dis.readInt();
+    int width = dis.readInt();
+
+    long hashA[] = new long[depth];
+    for (int i = 0; i < depth; ++i) {
+      hashA[i] = dis.readLong();
+    }
+
+    long table[][] = new long[depth][width];
+    for (int i = 0; i < depth; ++i) {
+      for (int j = 0; j < width; ++j) {
+        table[i][j] = dis.readLong();
+      }
     }
+
+    return new CountMinSketchImpl(depth, width, totalCount, hashA, table);
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/6f0f1d9e/common/sketch/src/main/java/org/apache/spark/util/sketch/IncompatibleMergeException.java
----------------------------------------------------------------------
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/IncompatibleMergeException.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/IncompatibleMergeException.java
new file mode 100644
index 0000000..64b567c
--- /dev/null
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/IncompatibleMergeException.java
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.sketch;
+
+public class IncompatibleMergeException extends Exception {
+  public IncompatibleMergeException(String message) {
+    super(message);
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/6f0f1d9e/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala
----------------------------------------------------------------------
diff --git a/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala
index ec5b4ed..b9c7f5c 100644
--- a/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala
+++ b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.util.sketch
 
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+
 import scala.reflect.ClassTag
 import scala.util.Random
 
@@ -29,9 +31,22 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
 
   private val seed = 42
 
+  // Serializes and deserializes a given `CountMinSketch`, then checks whether 
the deserialized
+  // version is equivalent to the original one.
+  private def checkSerDe(sketch: CountMinSketch): Unit = {
+    val out = new ByteArrayOutputStream()
+    sketch.writeTo(out)
+
+    val in = new ByteArrayInputStream(out.toByteArray)
+    val deserialized = CountMinSketch.readFrom(in)
+
+    assert(sketch === deserialized)
+  }
+
   def testAccuracy[T: ClassTag](typeName: String)(itemGenerator: Random => T): 
Unit = {
     test(s"accuracy - $typeName") {
-      val r = new Random()
+      // Uses fixed seed to ensure reproducible test execution
+      val r = new Random(31)
 
       val numAllItems = 1000000
       val allItems = Array.fill(numAllItems)(itemGenerator(r))
@@ -45,7 +60,10 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
       }
 
       val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed)
+      checkSerDe(sketch)
+
       sampledItemIndices.foreach(i => sketch.add(allItems(i)))
+      checkSerDe(sketch)
 
       val probCorrect = {
         val numErrors = allItems.map { item =>
@@ -66,7 +84,9 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
 
   def testMergeInPlace[T: ClassTag](typeName: String)(itemGenerator: Random => 
T): Unit = {
     test(s"mergeInPlace - $typeName") {
-      val r = new Random()
+      // Uses fixed seed to ensure reproducible test execution
+      val r = new Random(31)
+
       val numToMerge = 5
       val numItemsPerSketch = 100000
       val perSketchItems = Array.fill(numToMerge, numItemsPerSketch) {
@@ -75,11 +95,16 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
 
       val sketches = perSketchItems.map { items =>
         val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed)
+        checkSerDe(sketch)
+
         items.foreach(sketch.add)
+        checkSerDe(sketch)
+
         sketch
       }
 
       val mergedSketch = sketches.reduce(_ mergeInPlace _)
+      checkSerDe(mergedSketch)
 
       val expectedSketch = {
         val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed)
@@ -109,4 +134,22 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
   testItemType[Long]("Long") { _.nextLong() }
 
   testItemType[String]("String") { r => r.nextString(r.nextInt(20)) }
+
+  test("incompatible merge") {
+    intercept[IncompatibleMergeException] {
+      CountMinSketch.create(10, 10, 1).mergeInPlace(null)
+    }
+
+    intercept[IncompatibleMergeException] {
+      val sketch1 = CountMinSketch.create(10, 20, 1)
+      val sketch2 = CountMinSketch.create(10, 20, 2)
+      sketch1.mergeInPlace(sketch2)
+    }
+
+    intercept[IncompatibleMergeException] {
+      val sketch1 = CountMinSketch.create(10, 10, 1)
+      val sketch2 = CountMinSketch.create(10, 20, 2)
+      sketch1.mergeInPlace(sketch2)
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to