Repository: spark Updated Branches: refs/heads/master e2ae7bd04 -> 962e9bcf9
[SPARK-12756][SQL] use hash expression in Exchange This PR makes bucketing and exchange share one common hash algorithm, so that we can guarantee the data distribution is same between shuffle and bucketed data source, which enables us to only shuffle one side when join a bucketed table and a normal one. This PR also fixes the tests that are broken by the new hash behaviour in shuffle. Author: Wenchen Fan <wenc...@databricks.com> Closes #10703 from cloud-fan/use-hash-expr-in-shuffle. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/962e9bcf Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/962e9bcf Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/962e9bcf Branch: refs/heads/master Commit: 962e9bcf94da6f5134983f2bf1e56c5cd84f2bf7 Parents: e2ae7bd Author: Wenchen Fan <wenc...@databricks.com> Authored: Wed Jan 13 22:43:28 2016 -0800 Committer: Reynold Xin <r...@databricks.com> Committed: Wed Jan 13 22:43:28 2016 -0800 ---------------------------------------------------------------------- R/pkg/inst/tests/testthat/test_sparkSQL.R | 2 +- python/pyspark/sql/dataframe.py | 26 +++++++-------- python/pyspark/sql/group.py | 6 ++-- .../catalyst/plans/physical/partitioning.scala | 7 ++++- .../apache/spark/sql/execution/Exchange.scala | 12 +++++-- .../execution/datasources/WriterContainer.scala | 20 ++++++------ .../apache/spark/sql/JavaDataFrameSuite.java | 4 +-- .../org/apache/spark/sql/JavaDatasetSuite.java | 33 +++++++++++--------- .../org/apache/spark/sql/DataFrameSuite.scala | 21 +++++++------ .../org/apache/spark/sql/DatasetSuite.scala | 4 +-- .../org/apache/spark/sql/SQLQuerySuite.scala | 2 +- .../spark/sql/sources/BucketedWriteSuite.scala | 11 ++++--- 12 files changed, 84 insertions(+), 64 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/962e9bcf/R/pkg/inst/tests/testthat/test_sparkSQL.R ---------------------------------------------------------------------- diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 97625b9..40d5066 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1173,7 +1173,7 @@ test_that("group by, agg functions", { expect_equal(3, count(mean(gd))) expect_equal(3, count(max(gd))) - expect_equal(30, collect(max(gd))[1, 2]) + expect_equal(30, collect(max(gd))[2, 2]) expect_equal(1, collect(count(gd))[1, 2]) mockLines2 <- c("{\"name\":\"ID1\", \"value\": \"10\"}", http://git-wip-us.apache.org/repos/asf/spark/blob/962e9bcf/python/pyspark/sql/dataframe.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index a7bc288..90a6b5d 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -403,10 +403,10 @@ class DataFrame(object): +---+-----+ |age| name| +---+-----+ - | 2|Alice| - | 2|Alice| | 5| Bob| | 5| Bob| + | 2|Alice| + | 2|Alice| +---+-----+ >>> data = data.repartition(7, "age") >>> data.show() @@ -552,7 +552,7 @@ class DataFrame(object): >>> df_as2 = df.alias("df_as2") >>> joined_df = df_as1.join(df_as2, col("df_as1.name") == col("df_as2.name"), 'inner') >>> joined_df.select(col("df_as1.name"), col("df_as2.name"), col("df_as2.age")).collect() - [Row(name=u'Alice', name=u'Alice', age=2), Row(name=u'Bob', name=u'Bob', age=5)] + [Row(name=u'Bob', name=u'Bob', age=5), Row(name=u'Alice', name=u'Alice', age=2)] """ assert isinstance(alias, basestring), "alias should be a string" return DataFrame(getattr(self._jdf, "as")(alias), self.sql_ctx) @@ -573,14 +573,14 @@ class DataFrame(object): One of `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`. >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect() - [Row(name=None, height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)] + [Row(name=None, height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)] >>> df.join(df2, 'name', 'outer').select('name', 'height').collect() - [Row(name=u'Tom', height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)] + [Row(name=u'Tom', height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)] >>> cond = [df.name == df3.name, df.age == df3.age] >>> df.join(df3, cond, 'outer').select(df.name, df3.age).collect() - [Row(name=u'Bob', age=5), Row(name=u'Alice', age=2)] + [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)] >>> df.join(df2, 'name').select(df.name, df2.height).collect() [Row(name=u'Bob', height=85)] @@ -880,9 +880,9 @@ class DataFrame(object): >>> df.groupBy().avg().collect() [Row(avg(age)=3.5)] - >>> df.groupBy('name').agg({'age': 'mean'}).collect() + >>> sorted(df.groupBy('name').agg({'age': 'mean'}).collect()) [Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)] - >>> df.groupBy(df.name).avg().collect() + >>> sorted(df.groupBy(df.name).avg().collect()) [Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)] >>> df.groupBy(['name', df.age]).count().collect() [Row(name=u'Bob', age=5, count=1), Row(name=u'Alice', age=2, count=1)] @@ -901,11 +901,11 @@ class DataFrame(object): +-----+----+-----+ | name| age|count| +-----+----+-----+ - |Alice|null| 1| + |Alice| 2| 1| | Bob| 5| 1| | Bob|null| 1| | null|null| 2| - |Alice| 2| 1| + |Alice|null| 1| +-----+----+-----+ """ jgd = self._jdf.rollup(self._jcols(*cols)) @@ -923,12 +923,12 @@ class DataFrame(object): | name| age|count| +-----+----+-----+ | null| 2| 1| - |Alice|null| 1| + |Alice| 2| 1| | Bob| 5| 1| - | Bob|null| 1| | null| 5| 1| + | Bob|null| 1| | null|null| 2| - |Alice| 2| 1| + |Alice|null| 1| +-----+----+-----+ """ jgd = self._jdf.cube(self._jcols(*cols)) http://git-wip-us.apache.org/repos/asf/spark/blob/962e9bcf/python/pyspark/sql/group.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 9ca303a..ee734cb 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -74,11 +74,11 @@ class GroupedData(object): or a list of :class:`Column`. >>> gdf = df.groupBy(df.name) - >>> gdf.agg({"*": "count"}).collect() + >>> sorted(gdf.agg({"*": "count"}).collect()) [Row(name=u'Alice', count(1)=1), Row(name=u'Bob', count(1)=1)] >>> from pyspark.sql import functions as F - >>> gdf.agg(F.min(df.age)).collect() + >>> sorted(gdf.agg(F.min(df.age)).collect()) [Row(name=u'Alice', min(age)=2), Row(name=u'Bob', min(age)=5)] """ assert exprs, "exprs should not be empty" @@ -96,7 +96,7 @@ class GroupedData(object): def count(self): """Counts the number of records for each group. - >>> df.groupBy(df.age).count().collect() + >>> sorted(df.groupBy(df.age).count().collect()) [Row(age=2, count=1), Row(age=5, count=1)] """ http://git-wip-us.apache.org/repos/asf/spark/blob/962e9bcf/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 1bfe0ec..d6e10c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans.physical -import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder, Unevaluable} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{DataType, IntegerType} /** @@ -249,6 +249,11 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) case _ => false } + /** + * Returns an expression that will produce a valid partition ID(i.e. non-negative and is less + * than numPartitions) based on hashing expressions. + */ + def partitionIdExpression: Expression = Pmod(new Murmur3Hash(expressions), Literal(numPartitions)) } /** http://git-wip-us.apache.org/repos/asf/spark/blob/962e9bcf/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 058d147..3770883 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -143,7 +143,13 @@ case class Exchange( val rdd = child.execute() val part: Partitioner = newPartitioning match { case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions) - case HashPartitioning(expressions, numPartitions) => new HashPartitioner(numPartitions) + case HashPartitioning(_, n) => + new Partitioner { + override def numPartitions: Int = n + // For HashPartitioning, the partitioning key is already a valid partition ID, as we use + // `HashPartitioning.partitionIdExpression` to produce partitioning key. + override def getPartition(key: Any): Int = key.asInstanceOf[Int] + } case RangePartitioning(sortingExpressions, numPartitions) => // Internally, RangePartitioner runs a job on the RDD that samples keys to compute // partition bounds. To get accurate samples, we need to copy the mutable keys. @@ -173,7 +179,9 @@ case class Exchange( position += 1 position } - case HashPartitioning(expressions, _) => newMutableProjection(expressions, child.output)() + case h: HashPartitioning => + val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, child.output) + row => projection(row).getInt(0) case RangePartitioning(_, _) | SinglePartition => identity case _ => sys.error(s"Exchange not implemented for $newPartitioning") } http://git-wip-us.apache.org/repos/asf/spark/blob/962e9bcf/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index fff7287..fc77529 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.execution.datasources import java.util.{Date, UUID} -import scala.collection.JavaConverters._ - import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter} @@ -30,6 +28,7 @@ import org.apache.spark._ import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.UnsafeKVExternalSorter import org.apache.spark.sql.sources.{HadoopFsRelation, OutputWriter, OutputWriterFactory} @@ -322,9 +321,12 @@ private[sql] class DynamicPartitionWriterContainer( spec => spec.sortColumnNames.map(c => inputSchema.find(_.name == c).get) } - private def bucketIdExpression: Option[Expression] = for { - BucketSpec(numBuckets, _, _) <- bucketSpec - } yield Pmod(new Murmur3Hash(bucketColumns), Literal(numBuckets)) + private def bucketIdExpression: Option[Expression] = bucketSpec.map { spec => + // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can + // guarantee the data distribution is same between shuffle and bucketed data source, which + // enables us to only shuffle one side when join a bucketed table and a normal one. + HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression + } // Expressions that given a partition key build a string like: col1=val/col2=val/... private def partitionStringExpression: Seq[Expression] = { @@ -341,12 +343,8 @@ private[sql] class DynamicPartitionWriterContainer( } } - private def getBucketIdFromKey(key: InternalRow): Option[Int] = { - if (bucketSpec.isDefined) { - Some(key.getInt(partitionColumns.length)) - } else { - None - } + private def getBucketIdFromKey(key: InternalRow): Option[Int] = bucketSpec.map { _ => + key.getInt(partitionColumns.length) } /** http://git-wip-us.apache.org/repos/asf/spark/blob/962e9bcf/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java ---------------------------------------------------------------------- diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 8e0b2db..ac1607b 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -237,8 +237,8 @@ public class JavaDataFrameSuite { DataFrame crosstab = df.stat().crosstab("a", "b"); String[] columnNames = crosstab.schema().fieldNames(); Assert.assertEquals("a_b", columnNames[0]); - Assert.assertEquals("1", columnNames[1]); - Assert.assertEquals("2", columnNames[2]); + Assert.assertEquals("2", columnNames[1]); + Assert.assertEquals("1", columnNames[2]); Row[] rows = crosstab.collect(); Arrays.sort(rows, crosstabRowComparator); Integer count = 1; http://git-wip-us.apache.org/repos/asf/spark/blob/962e9bcf/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java ---------------------------------------------------------------------- diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 9f8db39..1a3df1b 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -187,7 +187,7 @@ public class JavaDatasetSuite implements Serializable { } }, Encoders.STRING()); - Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList()); + Assert.assertEquals(asSet("1a", "3foobar"), toSet(mapped.collectAsList())); Dataset<String> flatMapped = grouped.flatMapGroups( new FlatMapGroupsFunction<Integer, String, String>() { @@ -202,7 +202,7 @@ public class JavaDatasetSuite implements Serializable { }, Encoders.STRING()); - Assert.assertEquals(Arrays.asList("1a", "3foobar"), flatMapped.collectAsList()); + Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped.collectAsList())); Dataset<Tuple2<Integer, String>> reduced = grouped.reduce(new ReduceFunction<String>() { @Override @@ -212,8 +212,8 @@ public class JavaDatasetSuite implements Serializable { }); Assert.assertEquals( - Arrays.asList(tuple2(1, "a"), tuple2(3, "foobar")), - reduced.collectAsList()); + asSet(tuple2(1, "a"), tuple2(3, "foobar")), + toSet(reduced.collectAsList())); List<Integer> data2 = Arrays.asList(2, 6, 10); Dataset<Integer> ds2 = context.createDataset(data2, Encoders.INT()); @@ -245,7 +245,7 @@ public class JavaDatasetSuite implements Serializable { }, Encoders.STRING()); - Assert.assertEquals(Arrays.asList("1a#2", "3foobar#6", "5#10"), cogrouped.collectAsList()); + Assert.assertEquals(asSet("1a#2", "3foobar#6", "5#10"), toSet(cogrouped.collectAsList())); } @Test @@ -268,7 +268,7 @@ public class JavaDatasetSuite implements Serializable { }, Encoders.STRING()); - Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList()); + Assert.assertEquals(asSet("1a", "3foobar"), toSet(mapped.collectAsList())); } @Test @@ -290,9 +290,7 @@ public class JavaDatasetSuite implements Serializable { List<String> data = Arrays.asList("abc", "abc", "xyz"); Dataset<String> ds = context.createDataset(data, Encoders.STRING()); - Assert.assertEquals( - Arrays.asList("abc", "xyz"), - sort(ds.distinct().collectAsList().toArray(new String[0]))); + Assert.assertEquals(asSet("abc", "xyz"), toSet(ds.distinct().collectAsList())); List<String> data2 = Arrays.asList("xyz", "foo", "foo"); Dataset<String> ds2 = context.createDataset(data2, Encoders.STRING()); @@ -302,16 +300,23 @@ public class JavaDatasetSuite implements Serializable { Dataset<String> unioned = ds.union(ds2); Assert.assertEquals( - Arrays.asList("abc", "abc", "foo", "foo", "xyz", "xyz"), - sort(unioned.collectAsList().toArray(new String[0]))); + Arrays.asList("abc", "abc", "xyz", "xyz", "foo", "foo"), + unioned.collectAsList()); Dataset<String> subtracted = ds.subtract(ds2); Assert.assertEquals(Arrays.asList("abc", "abc"), subtracted.collectAsList()); } - private <T extends Comparable<T>> List<T> sort(T[] data) { - Arrays.sort(data); - return Arrays.asList(data); + private <T> Set<T> toSet(List<T> records) { + Set<T> set = new HashSet<T>(); + for (T record : records) { + set.add(record); + } + return set; + } + + private <T> Set<T> asSet(T... records) { + return toSet(Arrays.asList(records)); } @Test http://git-wip-us.apache.org/repos/asf/spark/blob/962e9bcf/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 983dfbd..d6c140d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1083,17 +1083,20 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { // Walk each partition and verify that it is sorted descending and does not contain all // the values. df4.rdd.foreachPartition { p => - var previousValue: Int = -1 - var allSequential: Boolean = true - p.foreach { r => - val v: Int = r.getInt(1) - if (previousValue != -1) { - if (previousValue < v) throw new SparkException("Partition is not ordered.") - if (v + 1 != previousValue) allSequential = false + // Skip empty partition + if (p.hasNext) { + var previousValue: Int = -1 + var allSequential: Boolean = true + p.foreach { r => + val v: Int = r.getInt(1) + if (previousValue != -1) { + if (previousValue < v) throw new SparkException("Partition is not ordered.") + if (v + 1 != previousValue) allSequential = false + } + previousValue = v } - previousValue = v + if (allSequential) throw new SparkException("Partition should not be globally ordered") } - if (allSequential) throw new SparkException("Partition should not be globally ordered") } // Distribute and order by with multiple order bys http://git-wip-us.apache.org/repos/asf/spark/blob/962e9bcf/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 693f5ae..d7b86e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -456,8 +456,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext { implicit val kryoEncoder = Encoders.kryo[KryoData] val ds = Seq(KryoData(1), KryoData(2)).toDS() - assert(ds.groupBy(p => p).count().collect().toSeq == - Seq((KryoData(1), 1L), (KryoData(2), 1L))) + assert(ds.groupBy(p => p).count().collect().toSet == + Set((KryoData(1), 1L), (KryoData(2), 1L))) } test("Kryo encoder self join") { http://git-wip-us.apache.org/repos/asf/spark/blob/962e9bcf/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5de0979..03d67c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -806,7 +806,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { sql("SELECT DISTINCT n FROM lowerCaseData ORDER BY n DESC") .limit(2) .registerTempTable("subset1") - sql("SELECT DISTINCT n FROM lowerCaseData") + sql("SELECT DISTINCT n FROM lowerCaseData ORDER BY n ASC") .limit(2) .registerTempTable("subset2") checkAnswer( http://git-wip-us.apache.org/repos/asf/spark/blob/962e9bcf/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala index b718b7c..3ea9826 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -20,11 +20,11 @@ package org.apache.spark.sql.sources import java.io.File import org.apache.spark.sql.{AnalysisException, QueryTest} -import org.apache.spark.sql.catalyst.expressions.{Murmur3Hash, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.util.Utils class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { import testImplicits._ @@ -98,11 +98,12 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle val qe = readBack.select(bucketCols.map(col): _*).queryExecution val rows = qe.toRdd.map(_.copy()).collect() - val getHashCode = - UnsafeProjection.create(new Murmur3Hash(qe.analyzed.output) :: Nil, qe.analyzed.output) + val getHashCode = UnsafeProjection.create( + HashPartitioning(qe.analyzed.output, 8).partitionIdExpression :: Nil, + qe.analyzed.output) for (row <- rows) { - val actualBucketId = Utils.nonNegativeMod(getHashCode(row).getInt(0), 8) + val actualBucketId = getHashCode(row).getInt(0) assert(actualBucketId == bucketId) } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org