Github user mgaido91 commented on a diff in the pull request: https://github.com/apache/spark/pull/19813#discussion_r154201470 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala --- @@ -0,0 +1,259 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.expressions._ + +/** + * Defines APIs used in expression code generation. + */ +object ExpressionCodegen { + + /** + * Given an expression, returns all the necessary parameters to evaluate it, so the generated + * code of this expression can be split in a function. + * The 1st string in returned tuple is the parameter strings used to call the function. + * The 2nd string in returned tuple is the parameter strings used to declare the function. + * + * Returns `None` if it can't produce valid parameters. + * + * Params to include: + * 1. Evaluated columns referred by this, children or deferred expressions. + * 2. Rows referred by this, children or deferred expressions. + * 3. Eliminated subexpressions referred by children expressions. 
+ */ + def getExpressionInputParams( + ctx: CodegenContext, + expr: Expression): Option[(Seq[String], Seq[String])] = { + val (inputAttrs, inputVars) = getInputVarsForChildren(ctx, expr) + val inputRows = ctx.INPUT_ROW +: getInputRowsForChildren(ctx, expr) + val subExprs = getSubExprInChildren(ctx, expr) + + val paramsFromRows = inputRows.distinct.filter(_ != null).map { row => + (row, s"InternalRow $row") + } + val paramsFromColumns = prepareFunctionParams(ctx, inputAttrs, inputVars) + val paramsFromSubExprs = getParamsForSubExprs(ctx, subExprs) + val paramsLength = getParamLength(ctx, inputAttrs, subExprs) + paramsFromRows.length + + // Maximum allowed parameter number for Java's method descriptor. + if (paramsLength > 255) { + None + } else { + val allParams = (paramsFromRows ++ paramsFromColumns ++ paramsFromSubExprs).unzip + val callParams = allParams._1.distinct + val declParams = allParams._2.distinct + Some((callParams, declParams)) + } + } + + /** + * Returns the eliminated subexpressions in the children expressions. + */ + def getSubExprInChildren(ctx: CodegenContext, expr: Expression): Seq[Expression] = { + expr.children.flatMap { child => + child.collect { + case e if ctx.subExprEliminationExprs.contains(e) => e + } + }.distinct + } + + /** + * Given the list of eliminated subexpressions used in the children expressions, returns the + * strings of function parameters. The first is the variable names used to call the function, + * the second is the parameters used to declare the function in generated code. 
+ */ + def getParamsForSubExprs( + ctx: CodegenContext, + subExprs: Seq[Expression]): Seq[(String, String)] = { + subExprs.flatMap { subExpr => + val argType = ctx.javaType(subExpr.dataType) + + val subExprState = ctx.subExprEliminationExprs(subExpr) + (subExprState.value, subExprState.isNull) + + if (!subExpr.nullable || subExprState.isNull == "true" || subExprState.isNull == "false") { + Seq((subExprState.value, s"$argType ${subExprState.value}")) + } else { + Seq((subExprState.value, s"$argType ${subExprState.value}"), + (subExprState.isNull, s"boolean ${subExprState.isNull}")) + } + }.distinct + } + + /** + * Retrieves previous input rows referred by children and deferred expressions. + */ + def getInputRowsForChildren(ctx: CodegenContext, expr: Expression): Seq[String] = { + expr.children.flatMap(getInputRows(ctx, _)).distinct + } + + /** + * Given a child expression, retrieves previous input rows referred by it or deferred expressions + * which are needed to evaluate it. + */ + def getInputRows(ctx: CodegenContext, child: Expression): Seq[String] = { + child.flatMap { + // An expression directly evaluates on current input row. + case BoundReference(ordinal, _, _) if ctx.currentVars == null || + ctx.currentVars(ordinal) == null => + Seq(ctx.INPUT_ROW) + + // An expression which is not evaluated yet. Tracks down to find input rows. + case BoundReference(ordinal, _, _) if ctx.currentVars(ordinal).code != "" => + trackDownRow(ctx, ctx.currentVars(ordinal)) + + case _ => Seq.empty + }.distinct + } + + /** + * Tracks down input rows referred by the generated code snippet. + */ + def trackDownRow(ctx: CodegenContext, exprCode: ExprCode): Seq[String] = { + var exprCodes: List[ExprCode] = List(exprCode) --- End diff -- same comment as for `trackDownVar`
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org