DRILL-4372: (continued) Type inference for HiveUDFs
Project: http://git-wip-us.apache.org/repos/asf/drill/repo Commit: http://git-wip-us.apache.org/repos/asf/drill/commit/9ecf4a48 Tree: http://git-wip-us.apache.org/repos/asf/drill/tree/9ecf4a48 Diff: http://git-wip-us.apache.org/repos/asf/drill/diff/9ecf4a48 Branch: refs/heads/master Commit: 9ecf4a484e2cc03f73aacd1b4f3801bb1909b71f Parents: c029335 Author: Hsuan-Yi Chu <hsua...@usc.edu> Authored: Thu Mar 3 20:14:59 2016 -0800 Committer: Hsuan-Yi Chu <hsua...@usc.edu> Committed: Wed Mar 16 20:57:16 2016 -0700 ---------------------------------------------------------------------- .../exec/expr/fn/HiveFunctionRegistry.java | 58 +++++++++++++++++++- .../drill/exec/planner/sql/HiveUDFOperator.java | 28 ++-------- .../drill/exec/fn/hive/TestInbuiltHiveUDFs.java | 28 ++++++++++ 3 files changed, 89 insertions(+), 25 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/drill/blob/9ecf4a48/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/expr/fn/HiveFunctionRegistry.java ---------------------------------------------------------------------- diff --git a/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/expr/fn/HiveFunctionRegistry.java b/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/expr/fn/HiveFunctionRegistry.java index 728954d..9a4e210 100644 --- a/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/expr/fn/HiveFunctionRegistry.java +++ b/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/expr/fn/HiveFunctionRegistry.java @@ -18,18 +18,32 @@ package org.apache.drill.exec.expr.fn; import java.util.HashSet; +import java.util.List; import java.util.Set; +import java.util.Collection; +import com.google.common.collect.Lists; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.sql.SqlOperatorBinding; +import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.apache.calcite.sql.type.SqlTypeName; import org.apache.drill.common.config.DrillConfig; +import org.apache.drill.common.exceptions.UserException; +import org.apache.drill.common.expression.ExpressionPosition; import org.apache.drill.common.expression.FunctionCall; +import org.apache.drill.common.expression.LogicalExpression; +import org.apache.drill.common.expression.MajorTypeInLogicalExpression; import org.apache.drill.common.scanner.ClassPathScanner; import org.apache.drill.common.scanner.persistence.ScanResult; +import org.apache.drill.common.types.TypeProtos; import org.apache.drill.common.types.TypeProtos.MajorType; import org.apache.drill.common.types.TypeProtos.MinorType; import org.apache.drill.common.types.Types; import org.apache.drill.exec.expr.fn.impl.hive.ObjectInspectorHelper; import org.apache.drill.exec.planner.sql.DrillOperatorTable; import org.apache.drill.exec.planner.sql.HiveUDFOperator; +import org.apache.drill.exec.planner.sql.TypeInferenceUtils; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDF; import org.apache.hadoop.hive.ql.udf.UDFType; @@ -70,7 +84,7 @@ public class HiveFunctionRegistry implements PluggableFunctionRegistry{ @Override public void register(DrillOperatorTable operatorTable) { for (String name : Sets.union(methodsGenericUDF.asMap().keySet(), methodsUDF.asMap().keySet())) { - operatorTable.add(name, new HiveUDFOperator(name.toUpperCase())); + operatorTable.add(name, new HiveUDFOperator(name.toUpperCase(), new HiveSqlReturnTypeInference())); } } @@ -204,4 +218,46 @@ public class HiveFunctionRegistry implements PluggableFunctionRegistry{ return null; } + public class HiveSqlReturnTypeInference implements SqlReturnTypeInference { + private HiveSqlReturnTypeInference() { + + } + + @Override + public RelDataType inferReturnType(SqlOperatorBinding opBinding) { + for (RelDataType type : opBinding.collectOperandTypes()) { + final TypeProtos.MinorType minorType = TypeInferenceUtils.getDrillTypeFromCalciteType(type); + if(minorType == TypeProtos.MinorType.LATE) { + return opBinding.getTypeFactory() + .createTypeWithNullability( + opBinding.getTypeFactory().createSqlType(SqlTypeName.ANY), + true); + } + } + + final FunctionCall functionCall = TypeInferenceUtils.convertSqlOperatorBindingToFunctionCall(opBinding); + final HiveFuncHolder hiveFuncHolder = getFunction(functionCall); + if(hiveFuncHolder == null) { + String operandTypes = ""; + for(int j = 0; j < opBinding.getOperandCount(); ++j) { + operandTypes += opBinding.getOperandType(j).getSqlTypeName(); + if(j < opBinding.getOperandCount() - 1) { + operandTypes += ","; + } + } + + throw UserException + .functionError() + .message(String.format("%s does not support operand types (%s)", + opBinding.getOperator().getName(), + operandTypes)) + .build(logger); + } + + return TypeInferenceUtils.createCalciteTypeWithNullability( + opBinding.getTypeFactory(), + TypeInferenceUtils.getCalciteTypeFromDrillType(hiveFuncHolder.getReturnType().getMinorType()), + hiveFuncHolder.getReturnType().getMode() != TypeProtos.DataMode.REQUIRED); + } + } } http://git-wip-us.apache.org/repos/asf/drill/blob/9ecf4a48/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/planner/sql/HiveUDFOperator.java ---------------------------------------------------------------------- diff --git a/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/planner/sql/HiveUDFOperator.java b/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/planner/sql/HiveUDFOperator.java index a9647bd..90c4135 100644 --- a/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/planner/sql/HiveUDFOperator.java +++ b/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/planner/sql/HiveUDFOperator.java @@ -18,28 +18,20 @@ package org.apache.drill.exec.planner.sql; -import com.fasterxml.jackson.databind.type.TypeFactory; -import org.apache.calcite.rel.type.RelDataType; -import org.apache.calcite.rel.type.RelDataTypeFactory; -import org.apache.calcite.sql.SqlCall; import org.apache.calcite.sql.SqlCallBinding; import org.apache.calcite.sql.SqlFunction; import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlIdentifier; import org.apache.calcite.sql.SqlOperandCountRange; import org.apache.calcite.sql.SqlOperator; -import org.apache.calcite.sql.SqlOperatorBinding; import org.apache.calcite.sql.parser.SqlParserPos; import org.apache.calcite.sql.type.SqlOperandCountRanges; import org.apache.calcite.sql.type.SqlOperandTypeChecker; -import org.apache.calcite.sql.type.SqlTypeName; -import org.apache.calcite.sql.validate.SqlValidator; -import org.apache.calcite.sql.validate.SqlValidatorScope; +import org.apache.calcite.sql.type.SqlReturnTypeInference; public class HiveUDFOperator extends SqlFunction { - - public HiveUDFOperator(String name) { - super(new SqlIdentifier(name, SqlParserPos.ZERO), DynamicReturnType.INSTANCE, null, new ArgChecker(), null, + public HiveUDFOperator(String name, SqlReturnTypeInference sqlReturnTypeInference) { + super(new SqlIdentifier(name, SqlParserPos.ZERO), sqlReturnTypeInference, null, new ArgChecker(), null, SqlFunctionCategory.USER_DEFINED_FUNCTION); } @@ -51,19 +43,7 @@ public class HiveUDFOperator extends SqlFunction { return false; } - @Override - public RelDataType deriveType(SqlValidator validator, SqlValidatorScope scope, SqlCall call) { - RelDataTypeFactory factory = validator.getTypeFactory(); - return factory.createTypeWithNullability(factory.createSqlType(SqlTypeName.ANY), true); - } - - @Override - public RelDataType inferReturnType(SqlOperatorBinding opBinding) { - RelDataTypeFactory factory = opBinding.getTypeFactory(); - return factory.createTypeWithNullability(factory.createSqlType(SqlTypeName.ANY), true); - } - - /** Argument Checker for variable number of arguments */ + /** Argument Checker for variable number of arguments */ public static class ArgChecker implements SqlOperandTypeChecker { public static ArgChecker INSTANCE = new ArgChecker(); http://git-wip-us.apache.org/repos/asf/drill/blob/9ecf4a48/contrib/storage-hive/core/src/test/java/org/apache/drill/exec/fn/hive/TestInbuiltHiveUDFs.java ---------------------------------------------------------------------- diff --git a/contrib/storage-hive/core/src/test/java/org/apache/drill/exec/fn/hive/TestInbuiltHiveUDFs.java b/contrib/storage-hive/core/src/test/java/org/apache/drill/exec/fn/hive/TestInbuiltHiveUDFs.java index aba7573..1439062 100644 --- a/contrib/storage-hive/core/src/test/java/org/apache/drill/exec/fn/hive/TestInbuiltHiveUDFs.java +++ b/contrib/storage-hive/core/src/test/java/org/apache/drill/exec/fn/hive/TestInbuiltHiveUDFs.java @@ -43,4 +43,32 @@ public class TestInbuiltHiveUDFs extends HiveTestBase { .baselineValues(new Object[] { null }) .go(); } + + @Test + public void testReflect() throws Exception { + final String query = "select reflect('java.lang.Math', 'round', cast(2 as float)) as col \n" + + "from hive.kv \n" + + "limit 1"; + + testBuilder() + .sqlQuery(query) + .unOrdered() + .baselineColumns("col") + .baselineValues("2") + .go(); + } + + @Test + public void testAbs() throws Exception { + final String query = "select reflect('java.lang.Math', 'abs', cast(-2 as double)) as col \n" + + "from hive.kv \n" + + "limit 1"; + + testBuilder() + .sqlQuery(query) + .unOrdered() + .baselineColumns("col") + .baselineValues("2.0") + .go(); + } }