Github user viirya commented on a diff in the pull request: https://github.com/apache/spark/pull/20858#discussion_r176901843 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala --- @@ -699,3 +699,88 @@ abstract class TernaryExpression extends Expression { * and Hive function wrappers. */ trait UserDefinedExpression + +/** + * The trait covers logic for performing null safe evaluation and code generation. + */ +trait NullSafeEvaluation extends Expression +{ + override def foldable: Boolean = children.forall(_.foldable) + + override def nullable: Boolean = children.exists(_.nullable) + + /** + * Default behavior of evaluation according to the default nullability of NullSafeEvaluation. + * If a class utilizing NullSafeEvaluation overrides [[nullable]], it probably should also + * override this. + */ + override def eval(input: InternalRow): Any = + { + val values = children.map(_.eval(input)) + if (values.contains(null)) null + else nullSafeEval(values) + } + + /** + * Called by default [[eval]] implementation. If a class utilizing NullSafeEvaluation keeps + * the default nullability, it can override this method to save null-check code. If we need + * full control of evaluation process, we should override [[eval]]. + */ + protected def nullSafeEval(inputs: Seq[Any]): Any = + sys.error(s"The class utilizing NullSaveEvaluation must override either eval or nullSafeEval") + + /** + * Shorthand for generating null safe evaluation code. + * If either of the sub-expressions is null, the result of this computation + * is assumed to be null. + * + * @param f accepts a sequence of variable names and returns Java code to compute the output. + */ + protected def defineCodeGen( + ctx: CodegenContext, + ev: ExprCode, + f: Seq[String] => String): ExprCode = { + nullSafeCodeGen(ctx, ev, values => { + s"${ev.value} = ${f(values)};" + }) + } + + /** + * Called by expressions to generate null safe evaluation code. 
+ * If either of the sub-expressions is null, the result of this computation + * is assumed to be null. + * + * @param f a function that accepts a sequence of non-null evaluation result names of children + * and returns Java code to compute the output. + */ + protected def nullSafeCodeGen( + ctx: CodegenContext, + ev: ExprCode, + f: Seq[String] => String): ExprCode = { + val gens = children.map(_.genCode(ctx)) + val resultCode = f(gens.map(_.value)) + + if (nullable) { + val nullSafeEval = + (s""" + ${ev.isNull} = false; // resultCode could change nullability. + $resultCode + """ /: children.zip(gens)) { --- End diff -- Use `foldLeft` for readability.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org