nikolamand-db commented on code in PR #45963: URL: https://github.com/apache/spark/pull/45963#discussion_r1596887611
########## common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java: ########## @@ -722,6 +722,65 @@ public static UTF8String execLowercase( } } + /** + * Utility class for collation aware Levenshtein function. + */ + public static class Levenshtein{ Review Comment: ```suggestion public static class Levenshtein { ``` ########## common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java: ########## @@ -722,6 +722,65 @@ public static UTF8String execLowercase( } } + /** + * Utility class for collation aware Levenshtein function. + */ + public static class Levenshtein{ + + /** + * Implementation of SubstringEquals interface for collation aware comparison of two substrings. + */ + private static class CollationSubstringEquals implements UTF8String.SubstringEquals { + private final int collationId; + private final UTF8String left, right; + + CollationSubstringEquals(int collationId) { + this.collationId = collationId; + this.left = new UTF8String(); + this.right = new UTF8String(); + } + + @Override + public boolean equals(UTF8String left, UTF8String right, int posLeft, int posRight, + int lenLeft, int lenRight) { + this.left.moveAddress(left, posLeft, lenLeft); + this.right.moveAddress(right, posRight, lenRight); + return CollationFactory.fetchCollation(collationId).equalsFunction + .apply(this.left, this.right); + } + } + + public static Integer exec(final UTF8String left, final UTF8String right, final int collationId){ + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + + if (collation.supportsBinaryEquality){ + return left.levenshteinDistance(right); + } + else{ Review Comment: ```suggestion else { ``` ########## common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java: ########## @@ -722,6 +722,65 @@ public static UTF8String execLowercase( } } + /** + * Utility class for collation aware Levenshtein function. 
+ */ + public static class Levenshtein{ + + /** + * Implementation of SubstringEquals interface for collation aware comparison of two substrings. + */ + private static class CollationSubstringEquals implements UTF8String.SubstringEquals { + private final int collationId; + private final UTF8String left, right; + + CollationSubstringEquals(int collationId) { + this.collationId = collationId; + this.left = new UTF8String(); + this.right = new UTF8String(); + } + + @Override + public boolean equals(UTF8String left, UTF8String right, int posLeft, int posRight, + int lenLeft, int lenRight) { + this.left.moveAddress(left, posLeft, lenLeft); + this.right.moveAddress(right, posRight, lenRight); + return CollationFactory.fetchCollation(collationId).equalsFunction + .apply(this.left, this.right); + } + } + + public static Integer exec(final UTF8String left, final UTF8String right, final int collationId){ + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + + if (collation.supportsBinaryEquality){ Review Comment: ```suggestion if (collation.supportsBinaryEquality) { ``` ########## common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java: ########## @@ -722,6 +722,65 @@ public static UTF8String execLowercase( } } + /** + * Utility class for collation aware Levenshtein function. + */ + public static class Levenshtein{ + + /** + * Implementation of SubstringEquals interface for collation aware comparison of two substrings. 
+ */ + private static class CollationSubstringEquals implements UTF8String.SubstringEquals { + private final int collationId; + private final UTF8String left, right; + + CollationSubstringEquals(int collationId) { + this.collationId = collationId; + this.left = new UTF8String(); + this.right = new UTF8String(); + } + + @Override + public boolean equals(UTF8String left, UTF8String right, int posLeft, int posRight, + int lenLeft, int lenRight) { + this.left.moveAddress(left, posLeft, lenLeft); + this.right.moveAddress(right, posRight, lenRight); + return CollationFactory.fetchCollation(collationId).equalsFunction + .apply(this.left, this.right); + } + } + + public static Integer exec(final UTF8String left, final UTF8String right, final int collationId){ + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + + if (collation.supportsBinaryEquality){ + return left.levenshteinDistance(right); + } + else{ + return left.levenshteinDistance(right, new CollationSubstringEquals(collationId)); + } + } + + public static Integer execWithThreshold(final UTF8String left, final UTF8String right, final int threshold, final int collationId){ + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + + if (collation.supportsBinaryEquality){ + return left.levenshteinDistance(right, threshold); + } + else{ + return left.levenshteinDistance(right, threshold, new CollationSubstringEquals(collationId)); + } Review Comment: We need to split logic into `execBinary` and `execNonBinary` here, please look at examples such as https://github.com/apache/spark/pull/45963/files#diff-3095052ebd126ed810dceb295897d0d7b1fba11b19a4da371ab43a2f41e96313R465 ########## common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java: ########## @@ -722,6 +722,65 @@ public static UTF8String execLowercase( } } + /** + * Utility class for collation aware Levenshtein function. 
+ */ + public static class Levenshtein{ + + /** + * Implementation of SubstringEquals interface for collation aware comparison of two substrings. + */ + private static class CollationSubstringEquals implements UTF8String.SubstringEquals { + private final int collationId; + private final UTF8String left, right; + + CollationSubstringEquals(int collationId) { + this.collationId = collationId; + this.left = new UTF8String(); + this.right = new UTF8String(); + } + + @Override + public boolean equals(UTF8String left, UTF8String right, int posLeft, int posRight, + int lenLeft, int lenRight) { + this.left.moveAddress(left, posLeft, lenLeft); + this.right.moveAddress(right, posRight, lenRight); + return CollationFactory.fetchCollation(collationId).equalsFunction + .apply(this.left, this.right); + } + } + + public static Integer exec(final UTF8String left, final UTF8String right, final int collationId){ Review Comment: ```suggestion public static Integer exec(final UTF8String left, final UTF8String right, final int collationId) { ``` ########## sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala: ########## @@ -959,6 +959,54 @@ class CollationStringExpressionsSuite assert(collationMismatch.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") } + test("Levenshtein expressions with collation") { + case class LevenshteinTestCase[R](l: String, r: String, c: String, result: R) + Seq( + LevenshteinTestCase("", "", "UTF8_BINARY", 0), + LevenshteinTestCase("", "something", "UTF8_BINARY", 9), + LevenshteinTestCase("a", "a", "UTF8_BINARY", 0), + LevenshteinTestCase("a", "A", "UTF8_BINARY", 1), + LevenshteinTestCase("a", "a", "UTF8_BINARY_LCASE", 0), + LevenshteinTestCase("a", "A", "UTF8_BINARY_LCASE", 0), + LevenshteinTestCase("bd", "ABc", "UTF8_BINARY_LCASE", 2), + LevenshteinTestCase("Xü", "Ü", "UTF8_BINARY_LCASE", 1), + LevenshteinTestCase("Xũ", "Üx", "UTF8_BINARY_LCASE", 2), + LevenshteinTestCase("", "something", "UTF8_BINARY_LCASE", 9), + 
LevenshteinTestCase("sOmeThINg", "SOMETHING", "UTF8_BINARY_LCASE", 0), + LevenshteinTestCase("sOmeThINg", "SOMETHING", "UNICODE", 5), + LevenshteinTestCase("sOmeThINg", "SOMETHING", "UNICODE_CI", 0) + ).foreach(c => { + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> "CODEGEN_ONLY") { + val query = s"SELECT levenshtein(collate('${c.l}', '${c.c}'), collate('${c.r}', '${c.c}'))" + // Result & data type + checkAnswer(sql(query), Row(c.result)) + assert(sql(query).schema.fields.head.dataType.sameType(IntegerType)) + } + }) + } + + test("Levenshtein expression with collation and threshold") { + case class LevenshteinThresholdTestCase(l: String, r: String, c: String, t: Int, result: Int) + Seq( + LevenshteinThresholdTestCase("", "", "UTF8_BINARY", 0, 0), + LevenshteinThresholdTestCase("", "something", "UTF8_BINARY", 0, -1), + LevenshteinThresholdTestCase("aaa", "AAA", "UTF8_BINARY_LCASE", 0, 0), + LevenshteinThresholdTestCase("a", "b", "UTF8_BINARY_LCASE", 1, 1), + LevenshteinThresholdTestCase("Xü", "Ü", "UTF8_BINARY_LCASE", 1, 1), + LevenshteinThresholdTestCase("Xũ", "Üx", "UTF8_BINARY_LCASE", 1, -1), + LevenshteinThresholdTestCase("sOmeThINg", "SOMETHING", "UNICODE", 0, -1), + LevenshteinThresholdTestCase("sOmeThINg", "SOMETHING", "UNICODE", 10, 5), + LevenshteinThresholdTestCase("sOmeThINg", "SOMETHING", "UNICODE_CI", 0, 0), + LevenshteinThresholdTestCase("sOmeThINg", "SOMETHING", "UNICODE_CI", 10, 0) + ).foreach(c => { + val query = s"SELECT levenshtein(collate('${c.l}', '${c.c}'), " + + s"collate('${c.r}', '${c.c}'), ${c.t})" + // Result & data type + checkAnswer(sql(query), Row(c.result)) + assert(sql(query).schema.fields.head.dataType.sameType(IntegerType)) + }) + } + Review Comment: Testing in this file can be more relaxed in general, so these cases are fine. However, Levenshtein distance should be unit-tested very thoroughly in https://github.com/apache/spark/blob/master/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java.
Please look at those examples and add analogous tests. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org