Repository: spark Updated Branches: refs/heads/master dcae355c6 -> 6f0f1d9e0
[SPARK-12934][SQL] Count-min sketch serialization This PR adds serialization support for `CountMinSketch`. A version number is added to version the serialized binary format. Author: Cheng Lian <l...@databricks.com> Closes #10893 from liancheng/cms-serialization. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6f0f1d9e Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6f0f1d9e Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6f0f1d9e Branch: refs/heads/master Commit: 6f0f1d9e04a8db47e2f6f8fcfe9dea9de0f633da Parents: dcae355 Author: Cheng Lian <l...@databricks.com> Authored: Mon Jan 25 15:05:05 2016 -0800 Committer: Reynold Xin <r...@databricks.com> Committed: Mon Jan 25 15:05:05 2016 -0800 ---------------------------------------------------------------------- .../spark/util/sketch/CountMinSketch.java | 32 ++++- .../spark/util/sketch/CountMinSketchImpl.java | 129 +++++++++++++++++-- .../util/sketch/IncompatibleMergeException.java | 24 ++++ .../spark/util/sketch/CountMinSketchSuite.scala | 47 ++++++- 4 files changed, 213 insertions(+), 19 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/6f0f1d9e/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java ---------------------------------------------------------------------- diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java index 21b161b..6793864 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java @@ -17,6 +17,7 @@ package org.apache.spark.util.sketch; +import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; @@ -55,6 +56,25 @@ import java.io.OutputStream; */ abstract public class CountMinSketch { /** + * Version number of the serialized binary format. + */ + public enum Version { + V1(1); + + private final int versionNumber; + + Version(int versionNumber) { + this.versionNumber = versionNumber; + } + + public int getVersionNumber() { + return versionNumber; + } + } + + public abstract Version version(); + + /** * Returns the relative error (or {@code eps}) of this {@link CountMinSketch}. */ public abstract double relativeError(); @@ -99,19 +119,23 @@ abstract public class CountMinSketch { * * Note that only Count-Min sketches with the same {@code depth}, {@code width}, and random seed * can be merged. + * + * @exception IncompatibleMergeException if the {@code other} {@link CountMinSketch} has + * incompatible depth, width, relative-error, confidence, or random seed. */ - public abstract CountMinSketch mergeInPlace(CountMinSketch other); + public abstract CountMinSketch mergeInPlace(CountMinSketch other) + throws IncompatibleMergeException; /** * Writes out this {@link CountMinSketch} to an output stream in binary format. */ - public abstract void writeTo(OutputStream out); + public abstract void writeTo(OutputStream out) throws IOException; /** * Reads in a {@link CountMinSketch} from an input stream. */ - public static CountMinSketch readFrom(InputStream in) { - throw new UnsupportedOperationException("Not implemented yet"); + public static CountMinSketch readFrom(InputStream in) throws IOException { + return CountMinSketchImpl.readFrom(in); } /** http://git-wip-us.apache.org/repos/asf/spark/blob/6f0f1d9e/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java ---------------------------------------------------------------------- diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java index e9fdbe3..0209446 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java @@ -17,11 +17,30 @@ package org.apache.spark.util.sketch; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.io.InputStream; import java.io.OutputStream; import java.io.UnsupportedEncodingException; import java.util.Arrays; import java.util.Random; +/* + * Binary format of a serialized CountMinSketchImpl, version 1 (all values written in big-endian + * order): + * + * - Version number, always 1 (32 bit) + * - Total count of added items (64 bit) + * - Depth (32 bit) + * - Width (32 bit) + * - Hash functions (depth * 64 bit) + * - Count table + * - Row 0 (width * 64 bit) + * - Row 1 (width * 64 bit) + * - ... + * - Row depth - 1 (width * 64 bit) + */ class CountMinSketchImpl extends CountMinSketch { public static final long PRIME_MODULUS = (1L << 31) - 1; @@ -33,7 +52,7 @@ class CountMinSketchImpl extends CountMinSketch { private double eps; private double confidence; - public CountMinSketchImpl(int depth, int width, int seed) { + CountMinSketchImpl(int depth, int width, int seed) { this.depth = depth; this.width = width; this.eps = 2.0 / width; @@ -41,7 +60,7 @@ class CountMinSketchImpl extends CountMinSketch { initTablesWith(depth, width, seed); } - public CountMinSketchImpl(double eps, double confidence, int seed) { + CountMinSketchImpl(double eps, double confidence, int seed) { // 2/w = eps ; w = 2/eps // 1/2^depth <= 1-confidence ; depth >= -log2 (1-confidence) this.eps = eps; @@ -51,6 +70,53 @@ class CountMinSketchImpl extends CountMinSketch { initTablesWith(depth, width, seed); } + CountMinSketchImpl(int depth, int width, long totalCount, long hashA[], long table[][]) { + this.depth = depth; + this.width = width; + this.eps = 2.0 / width; + this.confidence = 1 - 1 / Math.pow(2, depth); + this.hashA = hashA; + this.table = table; + this.totalCount = totalCount; + } + + @Override + public boolean equals(Object other) { + if (other == this) { + return true; + } + + if (other == null || !(other instanceof CountMinSketchImpl)) { + return false; + } + + CountMinSketchImpl that = (CountMinSketchImpl) other; + + return + this.depth == that.depth && + this.width == that.width && + this.totalCount == that.totalCount && + Arrays.equals(this.hashA, that.hashA) && + Arrays.deepEquals(this.table, that.table); + } + + @Override + public int hashCode() { + int hash = depth; + + hash = hash * 31 + width; + hash = hash * 31 + (int) (totalCount ^ (totalCount >>> 32)); + hash = hash * 31 + Arrays.hashCode(hashA); + hash = hash * 31 + Arrays.deepHashCode(table); + + return hash; + } + + @Override + public Version version() { + return Version.V1; + } + private void initTablesWith(int depth, int width, int seed) { this.table = new long[depth][width]; this.hashA = new long[depth]; @@ -221,27 +287,29 @@ class CountMinSketchImpl extends CountMinSketch { } @Override - public CountMinSketch mergeInPlace(CountMinSketch other) { + public CountMinSketch mergeInPlace(CountMinSketch other) throws IncompatibleMergeException { if (other == null) { - throw new CMSMergeException("Cannot merge null estimator"); + throw new IncompatibleMergeException("Cannot merge null estimator"); } if (!(other instanceof CountMinSketchImpl)) { - throw new CMSMergeException("Cannot merge estimator of class " + other.getClass().getName()); + throw new IncompatibleMergeException( + "Cannot merge estimator of class " + other.getClass().getName() + ); } CountMinSketchImpl that = (CountMinSketchImpl) other; if (this.depth != that.depth) { - throw new CMSMergeException("Cannot merge estimators of different depth"); + throw new IncompatibleMergeException("Cannot merge estimators of different depth"); } if (this.width != that.width) { - throw new CMSMergeException("Cannot merge estimators of different width"); + throw new IncompatibleMergeException("Cannot merge estimators of different width"); } if (!Arrays.equals(this.hashA, that.hashA)) { - throw new CMSMergeException("Cannot merge estimators of different seed"); + throw new IncompatibleMergeException("Cannot merge estimators of different seed"); } for (int i = 0; i < this.table.length; ++i) { @@ -256,13 +324,48 @@ class CountMinSketchImpl extends CountMinSketch { } @Override - public void writeTo(OutputStream out) { - throw new UnsupportedOperationException("Not implemented yet"); + public void writeTo(OutputStream out) throws IOException { + DataOutputStream dos = new DataOutputStream(out); + + dos.writeInt(version().getVersionNumber()); + + dos.writeLong(this.totalCount); + dos.writeInt(this.depth); + dos.writeInt(this.width); + + for (int i = 0; i < this.depth; ++i) { + dos.writeLong(this.hashA[i]); + } + + for (int i = 0; i < this.depth; ++i) { + for (int j = 0; j < this.width; ++j) { + dos.writeLong(table[i][j]); + } + } } - protected static class CMSMergeException extends RuntimeException { - public CMSMergeException(String message) { - super(message); + public static CountMinSketchImpl readFrom(InputStream in) throws IOException { + DataInputStream dis = new DataInputStream(in); + + // Ignores version number + dis.readInt(); + + long totalCount = dis.readLong(); + int depth = dis.readInt(); + int width = dis.readInt(); + + long hashA[] = new long[depth]; + for (int i = 0; i < depth; ++i) { + hashA[i] = dis.readLong(); + } + + long table[][] = new long[depth][width]; + for (int i = 0; i < depth; ++i) { + for (int j = 0; j < width; ++j) { + table[i][j] = dis.readLong(); + } } + + return new CountMinSketchImpl(depth, width, totalCount, hashA, table); } } http://git-wip-us.apache.org/repos/asf/spark/blob/6f0f1d9e/common/sketch/src/main/java/org/apache/spark/util/sketch/IncompatibleMergeException.java ---------------------------------------------------------------------- diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/IncompatibleMergeException.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/IncompatibleMergeException.java new file mode 100644 index 0000000..64b567c --- /dev/null +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/IncompatibleMergeException.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.sketch; + +public class IncompatibleMergeException extends Exception { + public IncompatibleMergeException(String message) { + super(message); + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/6f0f1d9e/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala ---------------------------------------------------------------------- diff --git a/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala index ec5b4ed..b9c7f5c 100644 --- a/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala +++ b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.util.sketch +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} + import scala.reflect.ClassTag import scala.util.Random @@ -29,9 +31,22 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite private val seed = 42 + // Serializes and deserializes a given `CountMinSketch`, then checks whether the deserialized + // version is equivalent to the original one. + private def checkSerDe(sketch: CountMinSketch): Unit = { + val out = new ByteArrayOutputStream() + sketch.writeTo(out) + + val in = new ByteArrayInputStream(out.toByteArray) + val deserialized = CountMinSketch.readFrom(in) + + assert(sketch === deserialized) + } + def testAccuracy[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = { test(s"accuracy - $typeName") { - val r = new Random() + // Uses fixed seed to ensure reproducible test execution + val r = new Random(31) val numAllItems = 1000000 val allItems = Array.fill(numAllItems)(itemGenerator(r)) @@ -45,7 +60,10 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite } val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed) + checkSerDe(sketch) + sampledItemIndices.foreach(i => sketch.add(allItems(i))) + checkSerDe(sketch) val probCorrect = { val numErrors = allItems.map { item => @@ -66,7 +84,9 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite def testMergeInPlace[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = { test(s"mergeInPlace - $typeName") { - val r = new Random() + // Uses fixed seed to ensure reproducible test execution + val r = new Random(31) + val numToMerge = 5 val numItemsPerSketch = 100000 val perSketchItems = Array.fill(numToMerge, numItemsPerSketch) { @@ -75,11 +95,16 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite val sketches = perSketchItems.map { items => val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed) + checkSerDe(sketch) + items.foreach(sketch.add) + checkSerDe(sketch) + sketch } val mergedSketch = sketches.reduce(_ mergeInPlace _) + checkSerDe(mergedSketch) val expectedSketch = { val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed) @@ -109,4 +134,22 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite testItemType[Long]("Long") { _.nextLong() } testItemType[String]("String") { r => r.nextString(r.nextInt(20)) } + + test("incompatible merge") { + intercept[IncompatibleMergeException] { + CountMinSketch.create(10, 10, 1).mergeInPlace(null) + } + + intercept[IncompatibleMergeException] { + val sketch1 = CountMinSketch.create(10, 20, 1) + val sketch2 = CountMinSketch.create(10, 20, 2) + sketch1.mergeInPlace(sketch2) + } + + intercept[IncompatibleMergeException] { + val sketch1 = CountMinSketch.create(10, 10, 1) + val sketch2 = CountMinSketch.create(10, 20, 2) + sketch1.mergeInPlace(sketch2) + } + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org