Github user viirya commented on a diff in the pull request: https://github.com/apache/spark/pull/22253#discussion_r213864459 --- Diff: sql/core/src/test/java/test/org/apache/spark/sql/JavaColumnExpressionSuite.java --- @@ -0,0 +1,80 @@ +package test.org.apache.spark.sql; + +import org.apache.spark.api.java.function.FilterFunction; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.test.TestSparkSession; +import org.apache.spark.sql.types.StructType; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.util.*; + +import static org.apache.spark.sql.types.DataTypes.*; + +public class JavaColumnExpressionSuite { + + private transient TestSparkSession spark; + + @Before + public void setUp() { + spark = new TestSparkSession(); + } + + @After + public void tearDown() { + spark.stop(); + spark = null; + } + + @Test + public void isInCollectionWorksCorrectlyOnJava() { + List<Row> rows = Arrays.asList( + RowFactory.create(1, "x"), + RowFactory.create(2, "y"), + RowFactory.create(3, "z") + ); + StructType schema = createStructType(Arrays.asList( + createStructField("a", IntegerType, false), + createStructField("b", StringType, false) + )); + Dataset<Row> df = spark.createDataFrame(rows, schema); + // Test with different types of collections + Assert.assertTrue(Arrays.equals( + (Row[]) df.filter(df.col("a").isInCollection(Arrays.asList(1, 2))).collect(), + (Row[]) df.filter((FilterFunction<Row>) r -> r.getInt(0) == 1 || r.getInt(0) == 2).collect() + )); + Assert.assertTrue(Arrays.equals( + (Row[]) df.filter(df.col("a").isInCollection(new HashSet<>(Arrays.asList(1, 2)))).collect(), + (Row[]) df.filter((FilterFunction<Row>) r -> r.getInt(0) == 1 || r.getInt(0) == 2).collect() + )); + Assert.assertTrue(Arrays.equals( + (Row[]) df.filter(df.col("a").isInCollection(new ArrayList<>(Arrays.asList(3, 1)))).collect(), + (Row[]) df.filter((FilterFunction<Row>) r -> r.getInt(0) == 3 || r.getInt(0) == 1).collect() + )); + } + + @Test + public void isInCollectionThrowsExceptionWithCorrectMessageOnJava() { --- End diff -- Can we shorten this method name? E.g., `checkExceptionMessage`?
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org