This is an automated email from the ASF dual-hosted git repository.

dwysakowicz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new d156d09  [FLINK-16830][table-api] Let users use Row/List/Map/Seq 
directly in Expression DSL
d156d09 is described below

commit d156d09a6cb3f21245da4ae73b263ee4c8dfe4b7
Author: Dawid Wysakowicz <dwysakow...@apache.org>
AuthorDate: Mon Mar 30 07:48:42 2020 +0200

    [FLINK-16830][table-api] Let users use Row/List/Map/Seq directly in 
Expression DSL
    
    This PR implements conversion logic from Row/List/Map/Seq to a 
corresponding Expression. From now on users can use those types in the 
Expression DSL without the need to convert those types manually.
    
    It is also a prerequisite for using those types directly in 
TableEnvironment#fromValues.
---
 .../table/expressions/ApiExpressionUtils.java      | 150 ++++++++++++++++++++-
 .../table/expressions/ObjectToExpressionTest.java  |  85 ++++++++++++
 .../org/apache/flink/table/api/expressionDsl.scala |  96 +++++++------
 .../api/ExpressionsConsistencyCheckTest.scala      |   7 +-
 .../flink/table/api/ImplicitConversionsTest.scala  |  91 +++++++++++++
 5 files changed, 376 insertions(+), 53 deletions(-)

diff --git 
a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ApiExpressionUtils.java
 
b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ApiExpressionUtils.java
index 92f8d83..1d8998a 100644
--- 
a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ApiExpressionUtils.java
+++ 
b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ApiExpressionUtils.java
@@ -24,15 +24,26 @@ import org.apache.flink.table.api.DataTypes;
 import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.ValidationException;
 import org.apache.flink.table.functions.BuiltInFunctionDefinition;
+import org.apache.flink.table.functions.BuiltInFunctionDefinitions;
 import org.apache.flink.table.functions.FunctionDefinition;
 import org.apache.flink.table.functions.FunctionIdentifier;
 import org.apache.flink.table.functions.FunctionKind;
 import org.apache.flink.table.operations.QueryOperation;
 import org.apache.flink.table.types.DataType;
+import org.apache.flink.types.Row;
 
+import java.lang.reflect.Array;
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+import java.math.BigDecimal;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
+import java.util.Map;
+import java.util.Optional;
 import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import java.util.stream.Stream;
 
 /**
  * Utilities for API-specific {@link Expression}s.
@@ -52,14 +63,151 @@ public final class ApiExpressionUtils {
                // private
        }
 
+       /**
+        * Converts a given object to an expression.
+        *
+        * <p>It converts:
+        * <ul>
+        *     <li>{@link Row} to a call to a row constructor expression</li>
+        *     <li>{@link Map} to a call to a map constructor expression</li>
+        *     <li>{@link List} to a call to an array constructor 
expression</li>
+        *     <li>arrays to a call to an array constructor expression</li>
+        *     <li>Scala's {@code Seq} to an array constructor via 
reflection</li>
+        *     <li>Scala's {@code Map} to a map constructor via reflection</li>
+        *     <li>Scala's {@code BigDecimal} to a DECIMAL literal</li>
+        *     <li>if none of the above applies, the function tries to convert 
the object
+        *          to a value literal with {@link #valueLiteral(Object)}</li>
+        * </ul>
+        *
+        * @param expression An object to convert to an expression
+        */
        public static Expression objectToExpression(Object expression) {
                if (expression instanceof ApiExpression) {
                        return ((ApiExpression) expression).toExpr();
                } else if (expression instanceof Expression) {
                        return (Expression) expression;
+               } else if (expression instanceof Row) {
+                       return convertRow((Row) expression);
+               } else if (expression instanceof Map) {
+                       return convertJavaMap((Map<?, ?>) expression);
+               } else if (expression instanceof List) {
+                       return convertJavaList((List<?>) expression);
+               } else if (expression.getClass().isArray()) {
+                       return convertArray(expression);
                } else {
-                       return valueLiteral(expression);
+                       return convertScala(expression).orElseGet(() -> 
valueLiteral(expression));
+               }
+       }
+
+       private static Expression convertRow(Row expression) {
+               List<Expression> fields = IntStream.range(0, 
expression.getArity())
+                       .mapToObj(expression::getField)
+                       .map(ApiExpressionUtils::objectToExpression)
+                       .collect(Collectors.toList());
+
+               return unresolvedCall(BuiltInFunctionDefinitions.ROW, fields);
+       }
+
+       private static Expression convertJavaMap(Map<?, ?> expression) {
+               List<Expression> entries = expression.entrySet()
+                       .stream()
+                       .flatMap(e -> Stream.of(
+                               objectToExpression(e.getKey()),
+                               objectToExpression(e.getValue())
+                       )).collect(Collectors.toList());
+
+               return unresolvedCall(BuiltInFunctionDefinitions.MAP, entries);
+       }
+
+       private static Expression convertJavaList(List<?> expression) {
+               List<Expression> entries = expression
+                       .stream()
+                       .map(ApiExpressionUtils::objectToExpression)
+                       .collect(Collectors.toList());
+
+               return unresolvedCall(BuiltInFunctionDefinitions.ARRAY, 
entries);
+       }
+
+       private static Expression convertArray(Object expression) {
+               int length = Array.getLength(expression);
+               List<Expression> entries = IntStream.range(0, length)
+                       .mapToObj(idx -> Array.get(expression, idx))
+                       .map(ApiExpressionUtils::objectToExpression)
+                       .collect(Collectors.toList());
+               return unresolvedCall(BuiltInFunctionDefinitions.ARRAY, 
entries);
+       }
+
+       private static Optional<Expression> convertScala(Object obj) {
+               try {
+                       Optional<Expression> array = convertScalaSeq(obj);
+                       if (array.isPresent()) {
+                               return array;
+                       }
+
+                       Optional<Expression> bigDecimal = 
convertScalaBigDecimal(obj);
+                       if (bigDecimal.isPresent()) {
+                               return bigDecimal;
+                       }
+
+                       return convertScalaMap(obj);
+               } catch (Exception e) {
+                       return Optional.empty();
+               }
+       }
+
+       private static Optional<Expression> convertScalaMap(Object obj)
+                       throws ClassNotFoundException, NoSuchMethodException, 
IllegalAccessException, InvocationTargetException {
+               Class<?> mapClass = Class.forName("scala.collection.Map");
+               if (mapClass.isAssignableFrom(obj.getClass())) {
+                       Class<?> seqClass = 
Class.forName("scala.collection.Seq");
+                       Class<?> productClass = Class.forName("scala.Product");
+                       Method getElement = 
productClass.getMethod("productElement", int.class);
+                       Method toSeq = mapClass.getMethod("toSeq");
+                       Method getMethod = seqClass.getMethod("apply", 
Object.class);
+                       Method lengthMethod = seqClass.getMethod("length");
+
+                       Object mapAsSeq = toSeq.invoke(obj);
+                       List<Expression> entries = new ArrayList<>();
+                       for (int i = 0; i < (Integer) 
lengthMethod.invoke(mapAsSeq); i++) {
+                               Object mapEntry = getMethod.invoke(mapAsSeq, i);
+
+                               Object key = getElement.invoke(mapEntry, 0);
+                               Object value = getElement.invoke(mapEntry, 1);
+                               entries.add(objectToExpression(key));
+                               entries.add(objectToExpression(value));
+                       }
+
+                       return 
Optional.of(unresolvedCall(BuiltInFunctionDefinitions.MAP, entries));
+               }
+               return Optional.empty();
+       }
+
+       private static Optional<Expression> convertScalaSeq(Object obj)
+                       throws ClassNotFoundException, NoSuchMethodException, 
IllegalAccessException, InvocationTargetException {
+               Class<?> seqClass = Class.forName("scala.collection.Seq");
+               if (seqClass.isAssignableFrom(obj.getClass())) {
+                       Method getMethod = seqClass.getMethod("apply", 
Object.class);
+                       Method lengthMethod = seqClass.getMethod("length");
+
+                       List<Expression> entries = new ArrayList<>();
+                       for (int i = 0; i < (Integer) lengthMethod.invoke(obj); 
i++) {
+                               
entries.add(objectToExpression(getMethod.invoke(obj, i)));
+                       }
+
+                       return 
Optional.of(unresolvedCall(BuiltInFunctionDefinitions.ARRAY, entries));
+               }
+               return Optional.empty();
+       }
+
+       private static Optional<Expression> convertScalaBigDecimal(Object obj)
+                       throws ClassNotFoundException, NoSuchMethodException, 
IllegalAccessException, InvocationTargetException {
+               Class<?> decimalClass = Class.forName("scala.math.BigDecimal");
+               if (decimalClass.equals(obj.getClass())) {
+                       Method toJava = decimalClass.getMethod("bigDecimal");
+                       BigDecimal bigDecimal = (BigDecimal) toJava.invoke(obj);
+                       return Optional.of(valueLiteral(bigDecimal));
                }
+               return Optional.empty();
        }
 
        public static Expression unwrapFromApi(Expression expression) {
diff --git 
a/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/expressions/ObjectToExpressionTest.java
 
b/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/expressions/ObjectToExpressionTest.java
new file mode 100644
index 0000000..cb3fc25
--- /dev/null
+++ 
b/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/expressions/ObjectToExpressionTest.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.expressions;
+
+import org.apache.flink.types.Row;
+
+import org.junit.Test;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static java.util.Arrays.asList;
+import static java.util.Collections.singletonList;
+import static org.apache.flink.table.api.Expressions.array;
+import static org.apache.flink.table.api.Expressions.map;
+import static org.apache.flink.table.api.Expressions.row;
+import static 
org.apache.flink.table.expressions.ApiExpressionUtils.objectToExpression;
+import static 
org.apache.flink.table.expressions.ApiExpressionUtils.unwrapFromApi;
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.junit.Assert.assertThat;
+
+/**
+ * Tests for converting an object to a {@link Expression} via {@link 
ApiExpressionUtils#objectToExpression(Object)}.
+ */
+public class ObjectToExpressionTest {
+
+       @Test
+       public void testListConversion() {
+               Expression expr = objectToExpression(asList(1, 2));
+
+               assertThatEquals(expr, array(1, 2));
+       }
+
+       @Test
+       public void testNestedListConversion() {
+               Expression expr = objectToExpression(asList(singletonList(1), 
singletonList(2)));
+
+               assertThatEquals(expr, array(array(1), array(2)));
+       }
+
+       @Test
+       public void testMapConversion() {
+               Map<String, List<Integer>> map = new HashMap<>();
+               map.put("key1", singletonList(2));
+               map.put("key2", asList(1, 2));
+
+               Expression expr = objectToExpression(map);
+               assertThatEquals(
+                       expr,
+                       map(
+                               "key1", array(2),
+                               "key2", array(1, 2)
+                       )
+               );
+       }
+
+       @Test
+       public void testRowConversion() {
+               Expression expr = objectToExpression(Row.of(1, "ABC", new 
int[]{1, 2, 3}));
+
+               assertThatEquals(expr, row(1, "ABC", array(1, 2, 3)));
+       }
+
+       private static void assertThatEquals(Expression actual, Expression 
expected) {
+               assertThat(unwrapFromApi(actual), 
equalTo(unwrapFromApi(expected)));
+       }
+
+}
diff --git 
a/flink-table/flink-table-api-scala/src/main/scala/org/apache/flink/table/api/expressionDsl.scala
 
b/flink-table/flink-table-api-scala/src/main/scala/org/apache/flink/table/api/expressionDsl.scala
index d424031..2a7577a 100644
--- 
a/flink-table/flink-table-api-scala/src/main/scala/org/apache/flink/table/api/expressionDsl.scala
+++ 
b/flink-table/flink-table-api-scala/src/main/scala/org/apache/flink/table/api/expressionDsl.scala
@@ -25,11 +25,13 @@ import org.apache.flink.table.expressions._
 import org.apache.flink.table.functions.BuiltInFunctionDefinitions._
 import org.apache.flink.table.functions.{ScalarFunction, TableFunction, 
UserDefinedAggregateFunction, UserDefinedFunctionHelper, _}
 import org.apache.flink.table.types.DataType
+import org.apache.flink.types.Row
 
 import java.lang.{Boolean => JBoolean, Byte => JByte, Double => JDouble, Float 
=> JFloat, Integer => JInteger, Long => JLong, Short => JShort}
 import java.math.{BigDecimal => JBigDecimal}
 import java.sql.{Date, Time, Timestamp}
 import java.time.{LocalDate, LocalDateTime, LocalTime}
+import java.util.{List => JList, Map => JMap}
 
 import _root_.scala.language.implicitConversions
 
@@ -239,6 +241,11 @@ trait ImplicitExpressionConversions {
     def expr: Expression = unresolvedRef(s.name)
   }
 
+  implicit class AnyWithOperations[T](e: T)(implicit toExpr: T => Expression)
+      extends ImplicitExpressionOperations {
+    def expr: Expression = toExpr(e)
+  }
+
   implicit class LiteralLongExpression(l: Long) extends 
ImplicitExpressionOperations {
     def expr: Expression = valueLiteral(l)
   }
@@ -272,17 +279,17 @@ trait ImplicitExpressionConversions {
   }
 
   implicit class LiteralJavaDecimalExpression(javaDecimal: JBigDecimal)
-      extends ImplicitExpressionOperations {
+    extends ImplicitExpressionOperations {
     def expr: Expression = valueLiteral(javaDecimal)
   }
 
   implicit class LiteralScalaDecimalExpression(scalaDecimal: BigDecimal)
-      extends ImplicitExpressionOperations {
+    extends ImplicitExpressionOperations {
     def expr: Expression = valueLiteral(scalaDecimal.bigDecimal)
   }
 
   implicit class LiteralSqlDateExpression(sqlDate: Date)
-      extends ImplicitExpressionOperations {
+    extends ImplicitExpressionOperations {
     def expr: Expression = valueLiteral(sqlDate)
   }
 
@@ -292,7 +299,7 @@ trait ImplicitExpressionConversions {
   }
 
   implicit class LiteralSqlTimestampExpression(sqlTimestamp: Timestamp)
-      extends ImplicitExpressionOperations {
+    extends ImplicitExpressionOperations {
     def expr: Expression = valueLiteral(sqlTimestamp)
   }
 
@@ -370,10 +377,9 @@ trait ImplicitExpressionConversions {
      * Creates an unresolved reference to a table's field.
      *
      * Example:
-     * ```
+     * {{{
      * tab.select($"key", $"value")
-     * ```
-     * </pre>
+     * }}}
      */
     def $(args: Any*): Expression = unresolvedRef(sc.s(args: _*))
   }
@@ -392,20 +398,34 @@ trait ImplicitExpressionConversions {
 
   implicit def byte2Literal(b: Byte): Expression = valueLiteral(b)
 
+  implicit def byte2Literal(b: JByte): Expression = valueLiteral(b)
+
   implicit def short2Literal(s: Short): Expression = valueLiteral(s)
 
+  implicit def short2Literal(s: JShort): Expression = valueLiteral(s)
+
   implicit def int2Literal(i: Int): Expression = valueLiteral(i)
 
+  implicit def int2Literal(i: JInteger): Expression = valueLiteral(i)
+
   implicit def long2Literal(l: Long): Expression = valueLiteral(l)
 
+  implicit def long2Literal(l: JLong): Expression = valueLiteral(l)
+
   implicit def double2Literal(d: Double): Expression = valueLiteral(d)
 
+  implicit def double2Literal(d: JDouble): Expression = valueLiteral(d)
+
   implicit def float2Literal(d: Float): Expression = valueLiteral(d)
 
+  implicit def float2Literal(d: JFloat): Expression = valueLiteral(d)
+
   implicit def string2Literal(str: String): Expression = valueLiteral(str)
 
   implicit def boolean2Literal(bool: Boolean): Expression = valueLiteral(bool)
 
+  implicit def boolean2Literal(bool: JBoolean): Expression = valueLiteral(bool)
+
   implicit def javaDec2Literal(javaDec: JBigDecimal): Expression = 
valueLiteral(javaDec)
 
   implicit def scalaDec2Literal(scalaDec: BigDecimal): Expression =
@@ -425,54 +445,28 @@ trait ImplicitExpressionConversions {
   implicit def localDateTime2Literal(localDateTime: LocalDateTime): Expression 
=
     valueLiteral(localDateTime)
 
+  implicit def javaList2ArrayConstructor(jList: JList[_]): Expression = {
+    ApiExpressionUtils.objectToExpression(jList)
+  }
+
+  implicit def seq2ArrayConstructor(seq: Seq[_]): Expression = {
+    ApiExpressionUtils.objectToExpression(seq)
+  }
+
   implicit def array2ArrayConstructor(array: Array[_]): Expression = {
+    ApiExpressionUtils.objectToExpression(array)
+  }
 
-    def createArray(elements: Array[_]): Expression = {
-      unresolvedCall(BuiltInFunctionDefinitions.ARRAY, 
elements.map(valueLiteral): _*)
-    }
+  implicit def javaMap2MapConstructor(map: JMap[_, _]): Expression = {
+    ApiExpressionUtils.objectToExpression(map)
+  }
 
-    def convertArray(array: Array[_]): Expression = array match {
-      // primitives
-      case _: Array[Boolean] => createArray(array)
-      case _: Array[Byte] => createArray(array)
-      case _: Array[Short] => createArray(array)
-      case _: Array[Int] => createArray(array)
-      case _: Array[Long] => createArray(array)
-      case _: Array[Float] => createArray(array)
-      case _: Array[Double] => createArray(array)
-
-      // boxed types
-      case _: Array[JBoolean] => createArray(array)
-      case _: Array[JByte] => createArray(array)
-      case _: Array[JShort] => createArray(array)
-      case _: Array[JInteger] => createArray(array)
-      case _: Array[JLong] => createArray(array)
-      case _: Array[JFloat] => createArray(array)
-      case _: Array[JDouble] => createArray(array)
-
-      // others
-      case _: Array[String] => createArray(array)
-      case _: Array[JBigDecimal] => createArray(array)
-      case _: Array[Date] => createArray(array)
-      case _: Array[Time] => createArray(array)
-      case _: Array[Timestamp] => createArray(array)
-      case _: Array[LocalDate] => createArray(array)
-      case _: Array[LocalTime] => createArray(array)
-      case _: Array[LocalDateTime] => createArray(array)
-      case bda: Array[BigDecimal] => createArray(bda.map(_.bigDecimal))
-
-      case _ =>
-        // nested
-        if (array.length > 0 && array.head.isInstanceOf[Array[_]]) {
-          unresolvedCall(
-            BuiltInFunctionDefinitions.ARRAY,
-            array.map { na => convertArray(na.asInstanceOf[Array[_]]) } :_*)
-        } else {
-          throw new ValidationException("Unsupported array type.")
-        }
-    }
+  implicit def map2MapConstructor(map: Map[_, _]): Expression = {
+    ApiExpressionUtils.objectToExpression(map)
+  }
 
-    convertArray(array)
+  implicit def row2RowConstructor(rowObject: Row): Expression = {
+    ApiExpressionUtils.objectToExpression(rowObject)
   }
 
   // 
----------------------------------------------------------------------------------------------
diff --git 
a/flink-table/flink-table-api-scala/src/test/scala/org/apache/flink/table/api/ExpressionsConsistencyCheckTest.scala
 
b/flink-table/flink-table-api-scala/src/test/scala/org/apache/flink/table/api/ExpressionsConsistencyCheckTest.scala
index f76009e..d864bdc 100644
--- 
a/flink-table/flink-table-api-scala/src/test/scala/org/apache/flink/table/api/ExpressionsConsistencyCheckTest.scala
+++ 
b/flink-table/flink-table-api-scala/src/test/scala/org/apache/flink/table/api/ExpressionsConsistencyCheckTest.scala
@@ -90,7 +90,7 @@ class ExpressionsConsistencyCheckTest {
     //  Scala implicit conversions to ImplicitExpressionOperations
     
//-----------------------------------------------------------------------------------
     "WithOperations",
-    "apiExpressionToExpression",
+    "AnyWithOperations",
     "LiteralScalaDecimalExpression",
     "LiteralJavaDecimalExpression",
     "LiteralShortExpression",
@@ -127,6 +127,11 @@ class ExpressionsConsistencyCheckTest {
     "localDate2Literal",
     "float2Literal",
     "array2ArrayConstructor",
+    "seq2ArrayConstructor",
+    "javaList2ArrayConstructor",
+    "map2MapConstructor",
+    "javaMap2MapConstructor",
+    "row2RowConstructor",
     "tableSymbolToExpression",
 
     
//-----------------------------------------------------------------------------------
diff --git 
a/flink-table/flink-table-api-scala/src/test/scala/org/apache/flink/table/api/ImplicitConversionsTest.scala
 
b/flink-table/flink-table-api-scala/src/test/scala/org/apache/flink/table/api/ImplicitConversionsTest.scala
new file mode 100644
index 0000000..8fec57a
--- /dev/null
+++ 
b/flink-table/flink-table-api-scala/src/test/scala/org/apache/flink/table/api/ImplicitConversionsTest.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.api
+
+import org.apache.flink.table.expressions.ApiExpressionUtils.unwrapFromApi
+import org.apache.flink.table.expressions.Expression
+import org.apache.flink.types.Row
+
+import org.hamcrest.CoreMatchers.equalTo
+import org.junit.Assert.assertThat
+import org.junit.Test
+
+/**
+ * Tests for conversion between objects and [[Expression]]s used in Expression 
DSL.
+ */
+class ImplicitConversionsTest extends ImplicitExpressionConversions {
+  @Test
+  def testSeqConversion(): Unit = {
+    val expr = Seq(1, 2).toExpr
+
+    assertThatEquals(expr, array(1, 2))
+  }
+
+  @Test
+  def testSeqOfExpressionsConversion(): Unit = {
+    val expr = Seq(row(1, "ABC"), row(3, "DEF")).toExpr
+
+    assertThatEquals(expr, array(row(1, "ABC"), row(3, "DEF")))
+  }
+
+  @Test
+  def testListConversion(): Unit = {
+    val expr = List(1, 2).toExpr
+
+    assertThatEquals(expr, array(1, 2))
+  }
+
+  @Test
+  def testMapConversion(): Unit = {
+    val expr = Map("key1" -> List(2), "key2" -> List(1, 2)).toExpr
+
+    assertThatEquals(
+      expr,
+        map(
+          "key1", array(2),
+          "key2", array(1, 2)
+        )
+    )
+  }
+
+  @Test
+  def testNestedListConversion(): Unit = {
+    val expr = List(List(1), List(2)).toExpr
+
+    assertThatEquals(expr, array(array(1), array(2)))
+  }
+
+  @Test
+  def testRowConversion(): Unit = {
+    val expr = Row.of(Int.box(1), "ABC").toExpr
+
+    assertThatEquals(expr, row(1, "ABC"))
+  }
+
+  @Test
+  def testRowConversionWithScalaTypes(): Unit = {
+    val expr = Row.of(Int.box(1), Seq("ABC", "DEF"), BigDecimal(1234)).toExpr
+
+    assertThatEquals(expr, row(1, array("ABC", "DEF"), BigDecimal(1234)))
+  }
+
+  private def assertThatEquals(actual: Expression, expected: Expression): Unit 
= {
+    assertThat(unwrapFromApi(actual), equalTo(unwrapFromApi(expected)))
+  }
+}

Reply via email to