Repository: spark Updated Branches: refs/heads/master 135ff16a3 -> c685b5f56
[SPARK-24411][SQL] Adding native Java tests for 'isInCollection' ## What changes were proposed in this pull request? `JavaColumnExpressionSuite.java` was added and `org.apache.spark.sql.ColumnExpressionSuite#test("isInCollection: Java Collection")` was removed. It provides native Java tests for the method `org.apache.spark.sql.Column#isInCollection`. Closes #22253 from aai95/isInCollectionJavaTest. Authored-by: aai95 <aa...@yandex.ru> Signed-off-by: DB Tsai <d_t...@apple.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c685b5f5 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c685b5f5 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c685b5f5 Branch: refs/heads/master Commit: c685b5f56a69abdf77e07e852b9bb2c6f2e715c9 Parents: 135ff16 Author: aai95 <aa...@yandex.ru> Authored: Thu Aug 30 20:38:03 2018 +0000 Committer: DB Tsai <d_t...@apple.com> Committed: Thu Aug 30 20:38:03 2018 +0000 ---------------------------------------------------------------------- .../spark/sql/JavaColumnExpressionSuite.java | 95 ++++++++++++++++++++ .../spark/sql/ColumnExpressionSuite.scala | 21 ----- 2 files changed, 95 insertions(+), 21 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/c685b5f5/sql/core/src/test/java/test/org/apache/spark/sql/JavaColumnExpressionSuite.java ---------------------------------------------------------------------- diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaColumnExpressionSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaColumnExpressionSuite.java new file mode 100644 index 0000000..38d606c --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaColumnExpressionSuite.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql; + +import org.apache.spark.api.java.function.FilterFunction; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.test.TestSparkSession; +import org.apache.spark.sql.types.StructType; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.util.*; + +import static org.apache.spark.sql.types.DataTypes.*; + +public class JavaColumnExpressionSuite { + private transient TestSparkSession spark; + + @Before + public void setUp() { + spark = new TestSparkSession(); + } + + @After + public void tearDown() { + spark.stop(); + spark = null; + } + + @Test + public void isInCollectionWorksCorrectlyOnJava() { + List<Row> rows = Arrays.asList( + RowFactory.create(1, "x"), + RowFactory.create(2, "y"), + RowFactory.create(3, "z")); + StructType schema = createStructType(Arrays.asList( + createStructField("a", IntegerType, false), + createStructField("b", StringType, false))); + Dataset<Row> df = spark.createDataFrame(rows, schema); + // Test with different types of collections + Assert.assertTrue(Arrays.equals( + (Row[]) df.filter(df.col("a").isInCollection(Arrays.asList(1, 2))).collect(), + (Row[]) df.filter((FilterFunction<Row>) r -> r.getInt(0) == 1 || r.getInt(0) == 2).collect() + )); + Assert.assertTrue(Arrays.equals( + (Row[]) df.filter(df.col("a").isInCollection(new HashSet<>(Arrays.asList(1, 2)))).collect(), + (Row[]) df.filter((FilterFunction<Row>) r -> r.getInt(0) == 1 || r.getInt(0) == 2).collect() + )); + Assert.assertTrue(Arrays.equals( + (Row[]) df.filter(df.col("a").isInCollection(new ArrayList<>(Arrays.asList(3, 1)))).collect(), + (Row[]) df.filter((FilterFunction<Row>) r -> r.getInt(0) == 3 || r.getInt(0) == 1).collect() + )); + } + + @Test + public void isInCollectionCheckExceptionMessage() { + List<Row> rows = Arrays.asList( + RowFactory.create(1, Arrays.asList(1)), + RowFactory.create(2, Arrays.asList(2)), + RowFactory.create(3, Arrays.asList(3))); + StructType schema = createStructType(Arrays.asList( + createStructField("a", IntegerType, false), + createStructField("b", createArrayType(IntegerType, false), false))); + Dataset<Row> df = spark.createDataFrame(rows, schema); + try { + df.filter(df.col("a").isInCollection(Arrays.asList(new Column("b")))); + Assert.fail("Expected org.apache.spark.sql.AnalysisException"); + } catch (Exception e) { + Arrays.asList("cannot resolve", + "due to data type mismatch: Arguments must be same type but were") + .forEach(s -> Assert.assertTrue( + e.getMessage().toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT)))); + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/c685b5f5/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 2182bd7..2917c56 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -436,27 +436,6 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { } } - test("isInCollection: Java Collection") { - val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") - // Test with different types of collections - checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).asJava)), - df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet.asJava)), - df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList.asJava)), - df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) - - val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") - - val e = intercept[AnalysisException] { - df2.filter($"a".isInCollection(Seq($"b").asJava)) - } - Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were") - .foreach { s => - assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) - } - } - test("&&") { checkAnswer( booleanData.filter($"a" && true), --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org