This is an automated email from the ASF dual-hosted git repository. dwysakowicz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push: new d156d09 [FLINK-16830][table-api] Let users use Row/List/Map/Seq directly in Expression DSL d156d09 is described below commit d156d09a6cb3f21245da4ae73b263ee4c8dfe4b7 Author: Dawid Wysakowicz <dwysakow...@apache.org> AuthorDate: Mon Mar 30 07:48:42 2020 +0200 [FLINK-16830][table-api] Let users use Row/List/Map/Seq directly in Expression DSL This PR implements conversion logic from Row/List/Map/Seq to a corresponding Expression. From now on, users can use those types in the Expression DSL without having to convert them manually. It is also a prerequisite for using those types directly in TableEnvironment#fromValues. --- .../table/expressions/ApiExpressionUtils.java | 150 ++++++++++++++++++++- .../table/expressions/ObjectToExpressionTest.java | 85 ++++++++++++ .../org/apache/flink/table/api/expressionDsl.scala | 96 +++++++------ .../api/ExpressionsConsistencyCheckTest.scala | 7 +- .../flink/table/api/ImplicitConversionsTest.scala | 91 +++++++++++++ 5 files changed, 376 insertions(+), 53 deletions(-) diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ApiExpressionUtils.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ApiExpressionUtils.java index 92f8d83..1d8998a 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ApiExpressionUtils.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ApiExpressionUtils.java @@ -24,15 +24,26 @@ import org.apache.flink.table.api.DataTypes; import org.apache.flink.table.api.Table; import org.apache.flink.table.api.ValidationException; import org.apache.flink.table.functions.BuiltInFunctionDefinition; +import org.apache.flink.table.functions.BuiltInFunctionDefinitions; import org.apache.flink.table.functions.FunctionDefinition; import org.apache.flink.table.functions.FunctionIdentifier; import 
org.apache.flink.table.functions.FunctionKind; import org.apache.flink.table.operations.QueryOperation; import org.apache.flink.table.types.DataType; +import org.apache.flink.types.Row; +import java.lang.reflect.Array; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.math.BigDecimal; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Map; +import java.util.Optional; import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; /** * Utilities for API-specific {@link Expression}s. @@ -52,14 +63,151 @@ public final class ApiExpressionUtils { // private } + /** + * Converts a given object to an expression. + * + * <p>It converts: + * <ul> + * <li>{@link Row} to a call to a row constructor expression</li> + * <li>{@link Map} to a call to a map constructor expression</li> + * <li>{@link List} to a call to an array constructor expression</li> + * <li>arrays to a call to an array constructor expression</li> + * <li>Scala's {@code Seq} to an array constructor via reflection</li> + * <li>Scala's {@code Map} to a map constructor via reflection</li> + * <li>Scala's {@code BigDecimal} to a DECIMAL literal</li> + * <li>if none of the above applies, the function tries to convert the object + * to a value literal with {@link #valueLiteral(Object)}</li> + * </ul> + * + * @param expression An object to convert to an expression + */ public static Expression objectToExpression(Object expression) { if (expression instanceof ApiExpression) { return ((ApiExpression) expression).toExpr(); } else if (expression instanceof Expression) { return (Expression) expression; + } else if (expression instanceof Row) { + return convertRow((Row) expression); + } else if (expression instanceof Map) { + return convertJavaMap((Map<?, ?>) expression); + } else if (expression instanceof List) { + return convertJavaList((List<?>) expression); + } else if 
(expression.getClass().isArray()) { + return convertArray(expression); } else { - return valueLiteral(expression); + return convertScala(expression).orElseGet(() -> valueLiteral(expression)); + } + } + + private static Expression convertRow(Row expression) { + List<Expression> fields = IntStream.range(0, expression.getArity()) + .mapToObj(expression::getField) + .map(ApiExpressionUtils::objectToExpression) + .collect(Collectors.toList()); + + return unresolvedCall(BuiltInFunctionDefinitions.ROW, fields); + } + + private static Expression convertJavaMap(Map<?, ?> expression) { + List<Expression> entries = expression.entrySet() + .stream() + .flatMap(e -> Stream.of( + objectToExpression(e.getKey()), + objectToExpression(e.getValue()) + )).collect(Collectors.toList()); + + return unresolvedCall(BuiltInFunctionDefinitions.MAP, entries); + } + + private static Expression convertJavaList(List<?> expression) { + List<Expression> entries = expression + .stream() + .map(ApiExpressionUtils::objectToExpression) + .collect(Collectors.toList()); + + return unresolvedCall(BuiltInFunctionDefinitions.ARRAY, entries); + } + + private static Expression convertArray(Object expression) { + int length = Array.getLength(expression); + List<Expression> entries = IntStream.range(0, length) + .mapToObj(idx -> Array.get(expression, idx)) + .map(ApiExpressionUtils::objectToExpression) + .collect(Collectors.toList()); + return unresolvedCall(BuiltInFunctionDefinitions.ARRAY, entries); + } + + private static Optional<Expression> convertScala(Object obj) { + try { + Optional<Expression> array = convertScalaSeq(obj); + if (array.isPresent()) { + return array; + } + + Optional<Expression> bigDecimal = convertScalaBigDecimal(obj); + if (bigDecimal.isPresent()) { + return bigDecimal; + } + + return convertScalaMap(obj); + } catch (Exception e) { + return Optional.empty(); + } + } + + private static Optional<Expression> convertScalaMap(Object obj) + throws ClassNotFoundException, 
NoSuchMethodException, IllegalAccessException, InvocationTargetException { + Class<?> mapClass = Class.forName("scala.collection.Map"); + if (mapClass.isAssignableFrom(obj.getClass())) { + Class<?> seqClass = Class.forName("scala.collection.Seq"); + Class<?> productClass = Class.forName("scala.Product"); + Method getElement = productClass.getMethod("productElement", int.class); + Method toSeq = mapClass.getMethod("toSeq"); + Method getMethod = seqClass.getMethod("apply", Object.class); + Method lengthMethod = seqClass.getMethod("length"); + + Object mapAsSeq = toSeq.invoke(obj); + List<Expression> entries = new ArrayList<>(); + for (int i = 0; i < (Integer) lengthMethod.invoke(mapAsSeq); i++) { + Object mapEntry = getMethod.invoke(mapAsSeq, i); + + Object key = getElement.invoke(mapEntry, 0); + Object value = getElement.invoke(mapEntry, 1); + entries.add(objectToExpression(key)); + entries.add(objectToExpression(value)); + } + + return Optional.of(unresolvedCall(BuiltInFunctionDefinitions.MAP, entries)); + } + return Optional.empty(); + } + + private static Optional<Expression> convertScalaSeq(Object obj) + throws ClassNotFoundException, NoSuchMethodException, IllegalAccessException, InvocationTargetException { + Class<?> seqClass = Class.forName("scala.collection.Seq"); + if (seqClass.isAssignableFrom(obj.getClass())) { + Method getMethod = seqClass.getMethod("apply", Object.class); + Method lengthMethod = seqClass.getMethod("length"); + + List<Expression> entries = new ArrayList<>(); + for (int i = 0; i < (Integer) lengthMethod.invoke(obj); i++) { + entries.add(objectToExpression(getMethod.invoke(obj, i))); + } + + return Optional.of(unresolvedCall(BuiltInFunctionDefinitions.ARRAY, entries)); + } + return Optional.empty(); + } + + private static Optional<Expression> convertScalaBigDecimal(Object obj) + throws ClassNotFoundException, NoSuchMethodException, IllegalAccessException, InvocationTargetException { + Class<?> decimalClass = 
Class.forName("scala.math.BigDecimal"); + if (decimalClass.equals(obj.getClass())) { + Method toJava = decimalClass.getMethod("bigDecimal"); + BigDecimal bigDecimal = (BigDecimal) toJava.invoke(obj); + return Optional.of(valueLiteral(bigDecimal)); } + return Optional.empty(); } public static Expression unwrapFromApi(Expression expression) { diff --git a/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/expressions/ObjectToExpressionTest.java b/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/expressions/ObjectToExpressionTest.java new file mode 100644 index 0000000..cb3fc25 --- /dev/null +++ b/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/expressions/ObjectToExpressionTest.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.expressions; + +import org.apache.flink.types.Row; + +import org.junit.Test; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static java.util.Arrays.asList; +import static java.util.Collections.singletonList; +import static org.apache.flink.table.api.Expressions.array; +import static org.apache.flink.table.api.Expressions.map; +import static org.apache.flink.table.api.Expressions.row; +import static org.apache.flink.table.expressions.ApiExpressionUtils.objectToExpression; +import static org.apache.flink.table.expressions.ApiExpressionUtils.unwrapFromApi; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.junit.Assert.assertThat; + +/** + * Tests for converting an object to a {@link Expression} via {@link ApiExpressionUtils#objectToExpression(Object)}. + */ +public class ObjectToExpressionTest { + + @Test + public void testListConversion() { + Expression expr = objectToExpression(asList(1, 2)); + + assertThatEquals(expr, array(1, 2)); + } + + @Test + public void testNestedListConversion() { + Expression expr = objectToExpression(asList(singletonList(1), singletonList(2))); + + assertThatEquals(expr, array(array(1), array(2))); + } + + @Test + public void testMapConversion() { + Map<String, List<Integer>> map = new HashMap<>(); + map.put("key1", singletonList(2)); + map.put("key2", asList(1, 2)); + + Expression expr = objectToExpression(map); + assertThatEquals( + expr, + map( + "key1", array(2), + "key2", array(1, 2) + ) + ); + } + + @Test + public void testRowConversion() { + Expression expr = objectToExpression(Row.of(1, "ABC", new int[]{1, 2, 3})); + + assertThatEquals(expr, row(1, "ABC", array(1, 2, 3))); + } + + private static void assertThatEquals(Expression actual, Expression expected) { + assertThat(unwrapFromApi(actual), equalTo(unwrapFromApi(expected))); + } + +} diff --git 
a/flink-table/flink-table-api-scala/src/main/scala/org/apache/flink/table/api/expressionDsl.scala b/flink-table/flink-table-api-scala/src/main/scala/org/apache/flink/table/api/expressionDsl.scala index d424031..2a7577a 100644 --- a/flink-table/flink-table-api-scala/src/main/scala/org/apache/flink/table/api/expressionDsl.scala +++ b/flink-table/flink-table-api-scala/src/main/scala/org/apache/flink/table/api/expressionDsl.scala @@ -25,11 +25,13 @@ import org.apache.flink.table.expressions._ import org.apache.flink.table.functions.BuiltInFunctionDefinitions._ import org.apache.flink.table.functions.{ScalarFunction, TableFunction, UserDefinedAggregateFunction, UserDefinedFunctionHelper, _} import org.apache.flink.table.types.DataType +import org.apache.flink.types.Row import java.lang.{Boolean => JBoolean, Byte => JByte, Double => JDouble, Float => JFloat, Integer => JInteger, Long => JLong, Short => JShort} import java.math.{BigDecimal => JBigDecimal} import java.sql.{Date, Time, Timestamp} import java.time.{LocalDate, LocalDateTime, LocalTime} +import java.util.{List => JList, Map => JMap} import _root_.scala.language.implicitConversions @@ -239,6 +241,11 @@ trait ImplicitExpressionConversions { def expr: Expression = unresolvedRef(s.name) } + implicit class AnyWithOperations[T](e: T)(implicit toExpr: T => Expression) + extends ImplicitExpressionOperations { + def expr: Expression = toExpr(e) + } + implicit class LiteralLongExpression(l: Long) extends ImplicitExpressionOperations { def expr: Expression = valueLiteral(l) } @@ -272,17 +279,17 @@ trait ImplicitExpressionConversions { } implicit class LiteralJavaDecimalExpression(javaDecimal: JBigDecimal) - extends ImplicitExpressionOperations { + extends ImplicitExpressionOperations { def expr: Expression = valueLiteral(javaDecimal) } implicit class LiteralScalaDecimalExpression(scalaDecimal: BigDecimal) - extends ImplicitExpressionOperations { + extends ImplicitExpressionOperations { def expr: Expression = 
valueLiteral(scalaDecimal.bigDecimal) } implicit class LiteralSqlDateExpression(sqlDate: Date) - extends ImplicitExpressionOperations { + extends ImplicitExpressionOperations { def expr: Expression = valueLiteral(sqlDate) } @@ -292,7 +299,7 @@ trait ImplicitExpressionConversions { } implicit class LiteralSqlTimestampExpression(sqlTimestamp: Timestamp) - extends ImplicitExpressionOperations { + extends ImplicitExpressionOperations { def expr: Expression = valueLiteral(sqlTimestamp) } @@ -370,10 +377,9 @@ trait ImplicitExpressionConversions { * Creates an unresolved reference to a table's field. * * Example: - * ``` + * {{{ * tab.select($"key", $"value") - * ``` - * </pre> + * }}} */ def $(args: Any*): Expression = unresolvedRef(sc.s(args: _*)) } @@ -392,20 +398,34 @@ trait ImplicitExpressionConversions { implicit def byte2Literal(b: Byte): Expression = valueLiteral(b) + implicit def byte2Literal(b: JByte): Expression = valueLiteral(b) + implicit def short2Literal(s: Short): Expression = valueLiteral(s) + implicit def short2Literal(s: JShort): Expression = valueLiteral(s) + implicit def int2Literal(i: Int): Expression = valueLiteral(i) + implicit def int2Literal(i: JInteger): Expression = valueLiteral(i) + implicit def long2Literal(l: Long): Expression = valueLiteral(l) + implicit def long2Literal(l: JLong): Expression = valueLiteral(l) + implicit def double2Literal(d: Double): Expression = valueLiteral(d) + implicit def double2Literal(d: JDouble): Expression = valueLiteral(d) + implicit def float2Literal(d: Float): Expression = valueLiteral(d) + implicit def float2Literal(d: JFloat): Expression = valueLiteral(d) + implicit def string2Literal(str: String): Expression = valueLiteral(str) implicit def boolean2Literal(bool: Boolean): Expression = valueLiteral(bool) + implicit def boolean2Literal(bool: JBoolean): Expression = valueLiteral(bool) + implicit def javaDec2Literal(javaDec: JBigDecimal): Expression = valueLiteral(javaDec) implicit def scalaDec2Literal(scalaDec: 
BigDecimal): Expression = @@ -425,54 +445,28 @@ trait ImplicitExpressionConversions { implicit def localDateTime2Literal(localDateTime: LocalDateTime): Expression = valueLiteral(localDateTime) + implicit def javaList2ArrayConstructor(jList: JList[_]): Expression = { + ApiExpressionUtils.objectToExpression(jList) + } + + implicit def seq2ArrayConstructor(seq: Seq[_]): Expression = { + ApiExpressionUtils.objectToExpression(seq) + } + implicit def array2ArrayConstructor(array: Array[_]): Expression = { + ApiExpressionUtils.objectToExpression(array) + } - def createArray(elements: Array[_]): Expression = { - unresolvedCall(BuiltInFunctionDefinitions.ARRAY, elements.map(valueLiteral): _*) - } + implicit def javaMap2MapConstructor(map: JMap[_, _]): Expression = { + ApiExpressionUtils.objectToExpression(map) + } - def convertArray(array: Array[_]): Expression = array match { - // primitives - case _: Array[Boolean] => createArray(array) - case _: Array[Byte] => createArray(array) - case _: Array[Short] => createArray(array) - case _: Array[Int] => createArray(array) - case _: Array[Long] => createArray(array) - case _: Array[Float] => createArray(array) - case _: Array[Double] => createArray(array) - - // boxed types - case _: Array[JBoolean] => createArray(array) - case _: Array[JByte] => createArray(array) - case _: Array[JShort] => createArray(array) - case _: Array[JInteger] => createArray(array) - case _: Array[JLong] => createArray(array) - case _: Array[JFloat] => createArray(array) - case _: Array[JDouble] => createArray(array) - - // others - case _: Array[String] => createArray(array) - case _: Array[JBigDecimal] => createArray(array) - case _: Array[Date] => createArray(array) - case _: Array[Time] => createArray(array) - case _: Array[Timestamp] => createArray(array) - case _: Array[LocalDate] => createArray(array) - case _: Array[LocalTime] => createArray(array) - case _: Array[LocalDateTime] => createArray(array) - case bda: Array[BigDecimal] => 
createArray(bda.map(_.bigDecimal)) - - case _ => - // nested - if (array.length > 0 && array.head.isInstanceOf[Array[_]]) { - unresolvedCall( - BuiltInFunctionDefinitions.ARRAY, - array.map { na => convertArray(na.asInstanceOf[Array[_]]) } :_*) - } else { - throw new ValidationException("Unsupported array type.") - } - } + implicit def map2MapConstructor(map: Map[_, _]): Expression = { + ApiExpressionUtils.objectToExpression(map) + } - convertArray(array) + implicit def row2RowConstructor(rowObject: Row): Expression = { + ApiExpressionUtils.objectToExpression(rowObject) } // ---------------------------------------------------------------------------------------------- diff --git a/flink-table/flink-table-api-scala/src/test/scala/org/apache/flink/table/api/ExpressionsConsistencyCheckTest.scala b/flink-table/flink-table-api-scala/src/test/scala/org/apache/flink/table/api/ExpressionsConsistencyCheckTest.scala index f76009e..d864bdc 100644 --- a/flink-table/flink-table-api-scala/src/test/scala/org/apache/flink/table/api/ExpressionsConsistencyCheckTest.scala +++ b/flink-table/flink-table-api-scala/src/test/scala/org/apache/flink/table/api/ExpressionsConsistencyCheckTest.scala @@ -90,7 +90,7 @@ class ExpressionsConsistencyCheckTest { // Scala implicit conversions to ImplicitExpressionOperations //----------------------------------------------------------------------------------- "WithOperations", - "apiExpressionToExpression", + "AnyWithOperations", "LiteralScalaDecimalExpression", "LiteralJavaDecimalExpression", "LiteralShortExpression", @@ -127,6 +127,11 @@ class ExpressionsConsistencyCheckTest { "localDate2Literal", "float2Literal", "array2ArrayConstructor", + "seq2ArrayConstructor", + "javaList2ArrayConstructor", + "map2MapConstructor", + "javaMap2MapConstructor", + "row2RowConstructor", "tableSymbolToExpression", //----------------------------------------------------------------------------------- diff --git 
a/flink-table/flink-table-api-scala/src/test/scala/org/apache/flink/table/api/ImplicitConversionsTest.scala b/flink-table/flink-table-api-scala/src/test/scala/org/apache/flink/table/api/ImplicitConversionsTest.scala new file mode 100644 index 0000000..8fec57a --- /dev/null +++ b/flink-table/flink-table-api-scala/src/test/scala/org/apache/flink/table/api/ImplicitConversionsTest.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.api + +import org.apache.flink.table.expressions.ApiExpressionUtils.unwrapFromApi +import org.apache.flink.table.expressions.Expression +import org.apache.flink.types.Row + +import org.hamcrest.CoreMatchers.equalTo +import org.junit.Assert.assertThat +import org.junit.Test + +/** + * Tests for conversion between objects and [[Expression]]s used in Expression DSL. 
+ */ +class ImplicitConversionsTest extends ImplicitExpressionConversions { + @Test + def testSeqConversion(): Unit = { + val expr = Seq(1, 2).toExpr + + assertThatEquals(expr, array(1, 2)) + } + + @Test + def testSeqOfExpressionsConversion(): Unit = { + val expr = Seq(row(1, "ABC"), row(3, "DEF")).toExpr + + assertThatEquals(expr, array(row(1, "ABC"), row(3, "DEF"))) + } + + @Test + def testListConversion(): Unit = { + val expr = List(1, 2).toExpr + + assertThatEquals(expr, array(1, 2)) + } + + @Test + def testMapConversion(): Unit = { + val expr = Map("key1" -> List(2), "key2" -> List(1, 2)).toExpr + + assertThatEquals( + expr, + map( + "key1", array(2), + "key2", array(1, 2) + ) + ) + } + + @Test + def testNestedListConversion(): Unit = { + val expr = List(List(1), List(2)).toExpr + + assertThatEquals(expr, array(array(1), array(2))) + } + + @Test + def testRowConversion(): Unit = { + val expr = Row.of(Int.box(1), "ABC").toExpr + + assertThatEquals(expr, row(1, "ABC")) + } + + @Test + def testRowConversionWithScalaTypes(): Unit = { + val expr = Row.of(Int.box(1), Seq("ABC", "DEF"), BigDecimal(1234)).toExpr + + assertThatEquals(expr, row(1, array("ABC", "DEF"), BigDecimal(1234))) + } + + private def assertThatEquals(actual: Expression, expected: Expression): Unit = { + assertThat(unwrapFromApi(actual), equalTo(unwrapFromApi(expected))) + } +}