dtenedor commented on code in PR #39747: URL: https://github.com/apache/spark/pull/39747#discussion_r1089460097
########## sql/core/src/test/resources/sql-tests/inputs/string-functions.sql: ########## @@ -231,3 +231,39 @@ CREATE TEMPORARY VIEW fmtTable(fmtField) AS SELECT * FROM VALUES ('invalidFormat SELECT to_binary('abc', fmtField) FROM fmtTable; -- Clean up DROP VIEW IF EXISTS fmtTable; +-- luhn_check +select luhn_check('4111111111111111'); +select luhn_check('5500000000000004'); +select luhn_check('340000000000009'); +select luhn_check('6011000000000004'); +select luhn_check('378282246310005'); +select luhn_check('6011000990139424'); +select luhn_check('1234567890'); +select luhn_check('4111111111111'); +select luhn_check('4111111111111112'); +select luhn_check('371449635398431'); +select luhn_check('371449635398432'); +select luhn_check('30569309025904'); +select luhn_check('30569309025905'); +select luhn_check('6011111111111117'); +select luhn_check('6011111111111118'); +select luhn_check('3530111333300000'); +select luhn_check('3530111333300001'); +select luhn_check('5105105105105100'); +select luhn_check('5105105105105106'); +select luhn_check('510B105105105106'); +select luhn_check('E105105105105106'); +select luhn_check('-5105105105105106'); +select luhn_check(5105105105105106); Review Comment: maybe add some cases with whitespace to show the resulting behavior? ########## sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala: ########## @@ -3039,3 +3039,113 @@ case class SplitPart ( partNum = newChildren.apply(2)) } } + +/** + * Function to check if a given number string is a valid Luhn number. + * Returns true, if the number string is a valid Luhn number, false otherwise + */ +@ExpressionDescription( + usage = """ + _FUNC_(str ) - Checks that a string of digits is valid according to the Luhn algorithm. + This checksum function is widely applied on credit card numbers and government identification + numbers to distinguish valid numbers from mistyped, incorrect numbers. + """, + examples = """ + Examples: + > SELECT _FUNC_('8112189876'); + true + > SELECT _FUNC_('79927398713'); + true + > SELECT _FUNC_('79927398714'); + false + """, + since = "3.5.0", + group = "string_funcs") +case class Luhncheck(child: Expression) extends UnaryExpression with ExpectsInputTypes { + + override def nullable: Boolean = false + + override protected def withNewChildInternal(newChild: Expression): Luhncheck = + copy(child = newChild) + + /** + * Expected input types from child expressions. The i-th position in the returned seq indicates Review Comment: no need to copy the method comment from the base `Expression` class unchanged, you can just leave the method un-commented if it is simple and self-explanatory. ########## sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala: ########## @@ -3039,3 +3039,113 @@ case class SplitPart ( partNum = newChildren.apply(2)) } } + +/** + * Function to check if a given number string is a valid Luhn number. + * Returns true, if the number string is a valid Luhn number, false otherwise + */ +@ExpressionDescription( + usage = """ + _FUNC_(str ) - Checks that a string of digits is valid according to the Luhn algorithm. + This checksum function is widely applied on credit card numbers and government identification + numbers to distinguish valid numbers from mistyped, incorrect numbers. + """, + examples = """ + Examples: + > SELECT _FUNC_('8112189876'); + true + > SELECT _FUNC_('79927398713'); + true + > SELECT _FUNC_('79927398714'); + false + """, + since = "3.5.0", + group = "string_funcs") +case class Luhncheck(child: Expression) extends UnaryExpression with ExpectsInputTypes { + + override def nullable: Boolean = false + + override protected def withNewChildInternal(newChild: Expression): Luhncheck = + copy(child = newChild) + + /** + * Expected input types from child expressions. The i-th position in the returned seq indicates + * the type requirement for the i-th child. + * + * The possible values at each position are: + * 1. a specific data type, for example, LongType, StringType. + * 2. a non-leaf abstract data type, for example, NumericType, IntegralType, FractionalType. + */ + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + + /** + * Returns the [[DataType]] of the result of evaluating this expression. It is invalid to query + * the dataType of an unresolved expression (i.e., when `resolved` == false). + */ + override def dataType: DataType = BooleanType + + /** + * Default behavior of evaluation according to the default nullability of UnaryExpression. If + * subclass of UnaryExpression override nullable, probably should also override this. + */ + override def eval(input: InternalRow): Any = Luhncheck.isLuhnNumber(child.eval(input)) + + /** + * Returns Java source code that can be compiled to evaluate this expression. The default + * behavior is to call the eval method of the expression. Concrete expression implementations + * should override this to do actual code generation. + * + * @param ctx + * a [[CodegenContext]] + * @param ev + * an [[ExprCode]] with unique terms. + * @return + * an [[ExprCode]] containing the Java source code to generate the given expression + */ + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + defineCodeGen( + ctx, + ev, + c => { + s"""org.apache.spark.sql.catalyst.expressions.Luhncheck.isLuhnNumber($c)""" + }) + } +} + +object Luhncheck { + + /** + * Function to check if a given number string is a valid Luhn number + * + * @param numberString + * the number string to check + * @return + * true if the number string is a valid Luhn number, false otherwise + */ + def isLuhnNumber(numberString: Any): Boolean = + numberString match { + case number: UTF8String => + val digits = number.toString + // Check if all characters in the input string are digits. + if (digits.forall(_.isDigit)) { + // Reverse the string so that we can use a foldLeft function + // and iterate through the digits in reverse order. + val (checkSum, _) = digits.reverse.foldLeft((0, false)) { case ((nSum, isSecond), d) => + // Convert the digit character to an int. + val digit = d.asDigit + // Double the digit if it's the second digit in the sequence. + val doubled = if (isSecond) digit * 2 else digit + // Add the two digits of the doubled number to the sum. + val sum = nSum + (doubled % 10) + (doubled / 10) + // Toggle the isSecond flag for the next iteration. + (sum, !isSecond) + } + // Check if the final sum is divisible by 10 + checkSum % 10 == 0 + } else { + false + } Review Comment: ```suggestion } ``` ########## sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala: ########## @@ -3039,3 +3039,113 @@ case class SplitPart ( partNum = newChildren.apply(2)) } } + +/** + * Function to check if a given number string is a valid Luhn number. + * Returns true, if the number string is a valid Luhn number, false otherwise + */ +@ExpressionDescription( + usage = """ + _FUNC_(str ) - Checks that a string of digits is valid according to the Luhn algorithm. + This checksum function is widely applied on credit card numbers and government identification + numbers to distinguish valid numbers from mistyped, incorrect numbers. + """, + examples = """ + Examples: + > SELECT _FUNC_('8112189876'); + true + > SELECT _FUNC_('79927398713'); + true + > SELECT _FUNC_('79927398714'); + false + """, + since = "3.5.0", + group = "string_funcs") +case class Luhncheck(child: Expression) extends UnaryExpression with ExpectsInputTypes { + + override def nullable: Boolean = false + + override protected def withNewChildInternal(newChild: Expression): Luhncheck = + copy(child = newChild) + + /** + * Expected input types from child expressions. The i-th position in the returned seq indicates + * the type requirement for the i-th child. + * + * The possible values at each position are: + * 1. a specific data type, for example, LongType, StringType. + * 2. a non-leaf abstract data type, for example, NumericType, IntegralType, FractionalType. + */ + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + + /** + * Returns the [[DataType]] of the result of evaluating this expression. It is invalid to query + * the dataType of an unresolved expression (i.e., when `resolved` == false). + */ + override def dataType: DataType = BooleanType + + /** + * Default behavior of evaluation according to the default nullability of UnaryExpression. If + * subclass of UnaryExpression override nullable, probably should also override this. + */ + override def eval(input: InternalRow): Any = Luhncheck.isLuhnNumber(child.eval(input)) + + /** + * Returns Java source code that can be compiled to evaluate this expression. The default + * behavior is to call the eval method of the expression. Concrete expression implementations + * should override this to do actual code generation. + * + * @param ctx + * a [[CodegenContext]] + * @param ev + * an [[ExprCode]] with unique terms. + * @return + * an [[ExprCode]] containing the Java source code to generate the given expression + */ + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + defineCodeGen( + ctx, + ev, + c => { + s"""org.apache.spark.sql.catalyst.expressions.Luhncheck.isLuhnNumber($c)""" + }) + } +} + +object Luhncheck { + + /** + * Function to check if a given number string is a valid Luhn number + * + * @param numberString + * the number string to check + * @return + * true if the number string is a valid Luhn number, false otherwise + */ + def isLuhnNumber(numberString: Any): Boolean = + numberString match { + case number: UTF8String => + val digits = number.toString + // Check if all characters in the input string are digits. + if (digits.forall(_.isDigit)) { + // Reverse the string so that we can use a foldLeft function + // and iterate through the digits in reverse order. + val (checkSum, _) = digits.reverse.foldLeft((0, false)) { case ((nSum, isSecond), d) => + // Convert the digit character to an int. + val digit = d.asDigit + // Double the digit if it's the second digit in the sequence. + val doubled = if (isSecond) digit * 2 else digit + // Add the two digits of the doubled number to the sum. + val sum = nSum + (doubled % 10) + (doubled / 10) + // Toggle the isSecond flag for the next iteration. + (sum, !isSecond) + } + // Check if the final sum is divisible by 10 + checkSum % 10 == 0 + } else { + false Review Comment: ```suggestion false ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org