This is an automated email from the ASF dual-hosted git repository.

ibzib pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new f4fcccc  BEAM-12166:Beam Sql - Combine Accumulator return Map fails with class cast exception
     new fc873f0  Merge pull request #14534 from anupd22/BEAM-12166
f4fcccc is described below

commit f4fccccd726481cacc47182ccf3fc12b7c93012b
Author: Anup D <anu...@nokia.com>
AuthorDate: Wed Apr 14 20:40:43 2021 +0530

    BEAM-12166:Beam Sql - Combine Accumulator return Map fails with class cast exception
---
 .../extensions/sql/impl/utils/CalciteUtils.java    |  21 ++--
 .../sdk/extensions/sql/BeamSqlDslUdfUdafTest.java  | 114 +++++++++++++++++++++
 2 files changed, 128 insertions(+), 7 deletions(-)

diff --git 
a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java
 
b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java
index 10ad199..34664ac 100644
--- 
a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java
+++ 
b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java
@@ -280,7 +280,8 @@ public class CalciteUtils {
   /**
    * SQL-Java type mapping, with specified Beam rules: <br>
    * 1. redirect {@link AbstractInstant} to {@link Date} so Calcite can 
recognize it. <br>
-   * 2. For a list, the component type is needed to create a Sql array type.
+   * 2. For a list, the component type is needed to create a Sql array type. 
<br>
+   * 3. For a Map, the component type is needed to create a Sql map type.
    *
    * @param type
    * @return Calcite RelDataType
@@ -291,13 +292,19 @@ public class CalciteUtils {
       return typeFactory.createJavaType(Date.class);
     } else if (type instanceof Class && 
ByteString.class.isAssignableFrom((Class<?>) type)) {
       return typeFactory.createJavaType(byte[].class);
-    } else if (type instanceof ParameterizedType
-        && java.util.List.class.isAssignableFrom(
-            (Class<?>) ((ParameterizedType) type).getRawType())) {
+    } else if (type instanceof ParameterizedType) {
       ParameterizedType parameterizedType = (ParameterizedType) type;
-      Class<?> genericType = (Class<?>) 
parameterizedType.getActualTypeArguments()[0];
-      RelDataType collectionElementType = 
typeFactory.createJavaType(genericType);
-      return typeFactory.createArrayType(collectionElementType, 
UNLIMITED_ARRAY_SIZE);
+      if (java.util.List.class.isAssignableFrom((Class<?>) 
parameterizedType.getRawType())) {
+        Class<?> genericType = (Class<?>) 
parameterizedType.getActualTypeArguments()[0];
+        RelDataType collectionElementType = 
typeFactory.createJavaType(genericType);
+        return typeFactory.createArrayType(collectionElementType, 
UNLIMITED_ARRAY_SIZE);
+      } else if (java.util.Map.class.isAssignableFrom((Class<?>) 
parameterizedType.getRawType())) {
+        Class<?> genericKeyType = (Class<?>) 
parameterizedType.getActualTypeArguments()[0];
+        Class<?> genericValueType = (Class<?>) 
parameterizedType.getActualTypeArguments()[1];
+        RelDataType mapElementKeyType = 
typeFactory.createJavaType(genericKeyType);
+        RelDataType mapElementValueType = 
typeFactory.createJavaType(genericValueType);
+        return typeFactory.createMapType(mapElementKeyType, 
mapElementValueType);
+      }
     }
     return typeFactory.createJavaType((Class) type);
   }
diff --git 
a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslUdfUdafTest.java
 
b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslUdfUdafTest.java
index b563ea1..9f4c9a2 100644
--- 
a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslUdfUdafTest.java
+++ 
b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslUdfUdafTest.java
@@ -27,7 +27,11 @@ import java.sql.Time;
 import java.sql.Timestamp;
 import java.time.LocalDate;
 import java.time.LocalTime;
+import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.stream.IntStream;
 import org.apache.beam.sdk.extensions.sql.impl.BeamCalciteTable;
@@ -137,6 +141,56 @@ public class BeamSqlDslUdfUdafTest extends BeamSqlDslBase {
     pipeline.run().waitUntilFinish();
   }
 
+  /** GROUP-BY with UDAF that returns Map. */
+  @Test
+  public void testUdafWithMapOutput() throws Exception {
+    Schema resultType =
+        Schema.builder()
+            .addInt32Field("f_int2")
+            .addMapField("squareAndAccumulateInMap", FieldType.STRING, 
FieldType.INT32)
+            .build();
+
+    Map<String, Integer> resultMap = new HashMap<String, Integer>();
+    resultMap.put("squareOf-1", 1);
+    resultMap.put("squareOf-2", 4);
+    resultMap.put("squareOf-3", 9);
+    resultMap.put("squareOf-4", 16);
+    Row row = Row.withSchema(resultType).addValues(0, resultMap).build();
+
+    String sql =
+        "SELECT f_int2,squareAndAccumulateInMap(f_int) AS 
`squareAndAccumulateInMap` FROM PCOLLECTION GROUP BY f_int2";
+    PCollection<Row> result =
+        boundedInput1.apply(
+            "testUdafWithMapOutput",
+            SqlTransform.query(sql)
+                .registerUdaf("squareAndAccumulateInMap", new 
SquareAndAccumulateInMap()));
+    PAssert.that(result).containsInAnyOrder(row);
+
+    pipeline.run().waitUntilFinish();
+  }
+
+  /** GROUP-BY with UDAF that returns List. */
+  @Test
+  public void testUdafWithListOutput() throws Exception {
+    Schema resultType =
+        Schema.builder()
+            .addInt32Field("f_int2")
+            .addArrayField("squareAndAccumulateInList", FieldType.INT32)
+            .build();
+    Row row = Row.withSchema(resultType).addValue(0).addArray(Arrays.asList(1, 
4, 9, 16)).build();
+
+    String sql =
+        "SELECT f_int2,squareAndAccumulateInList(f_int) AS 
`squareAndAccumulateInList` FROM PCOLLECTION GROUP BY f_int2";
+    PCollection<Row> result =
+        boundedInput1.apply(
+            "testUdafWithListOutput",
+            SqlTransform.query(sql)
+                .registerUdaf("squareAndAccumulateInList", new 
SquareAndAccumulateInList()));
+    PAssert.that(result).containsInAnyOrder(row);
+
+    pipeline.run().waitUntilFinish();
+  }
+
   @Test
   public void testUdfWithListOutput() throws Exception {
     Schema resultType = Schema.builder().addArrayField("array_field", 
FieldType.INT64).build();
@@ -458,4 +512,64 @@ public class BeamSqlDslUdfUdafTest extends BeamSqlDslBase {
       return BeamCalciteTable.of(new TestBoundedTable(schema).addRows(values));
     }
   }
+
+  /** UDAF(CombineFn) for test, which squares each input, tags it and returns 
them all in a Map. */
+  public static class SquareAndAccumulateInMap
+      extends CombineFn<Integer, Map<String, Integer>, Map<String, Integer>> {
+    @Override
+    public Map<String, Integer> createAccumulator() {
+      return new HashMap<String, Integer>();
+    }
+
+    @Override
+    public Map<String, Integer> addInput(Map<String, Integer> accumulator, 
Integer input) {
+      accumulator.put("squareOf-" + input, input * input);
+      return accumulator;
+    }
+
+    @Override
+    public Map<String, Integer> mergeAccumulators(Iterable<Map<String, 
Integer>> accumulators) {
+      Map<String, Integer> merged = createAccumulator();
+      for (Map<String, Integer> accumulator : accumulators) {
+        merged.putAll(accumulator);
+      }
+      return merged;
+    }
+
+    @Override
+    public Map<String, Integer> extractOutput(Map<String, Integer> 
accumulator) {
+      return accumulator;
+    }
+  }
+
+  /** UDAF(CombineFn) for test, which squares each input and returns them all 
in a List. */
+  public static class SquareAndAccumulateInList
+      extends CombineFn<Integer, List<Integer>, List<Integer>> {
+
+    @Override
+    public List<Integer> createAccumulator() {
+      return new ArrayList<Integer>();
+    }
+
+    @Override
+    public List<Integer> addInput(List<Integer> accumulator, Integer input) {
+      accumulator.add(input * input);
+      return accumulator;
+    }
+
+    @Override
+    public List<Integer> mergeAccumulators(Iterable<List<Integer>> 
accumulators) {
+      List<Integer> merged = createAccumulator();
+      for (List<Integer> accumulator : accumulators) {
+        merged.addAll(accumulator);
+      }
+      return merged;
+    }
+
+    @Override
+    public List<Integer> extractOutput(List<Integer> accumulator) {
+      Collections.sort(accumulator);
+      return accumulator;
+    }
+  }
 }

Reply via email to