This is an automated email from the ASF dual-hosted git repository. ibzib pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push: new f4fcccc BEAM-12166:Beam Sql - Combine Accumulator return Map fails with class cast exception new fc873f0 Merge pull request #14534 from anupd22/BEAM-12166 f4fcccc is described below commit f4fccccd726481cacc47182ccf3fc12b7c93012b Author: Anup D <anu...@nokia.com> AuthorDate: Wed Apr 14 20:40:43 2021 +0530 BEAM-12166:Beam Sql - Combine Accumulator return Map fails with class cast exception --- .../extensions/sql/impl/utils/CalciteUtils.java | 21 ++-- .../sdk/extensions/sql/BeamSqlDslUdfUdafTest.java | 114 +++++++++++++++++++++ 2 files changed, 128 insertions(+), 7 deletions(-) diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java index 10ad199..34664ac 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java @@ -280,7 +280,8 @@ public class CalciteUtils { /** * SQL-Java type mapping, with specified Beam rules: <br> * 1. redirect {@link AbstractInstant} to {@link Date} so Calcite can recognize it. <br> - * 2. For a list, the component type is needed to create a Sql array type. + * 2. For a list, the component type is needed to create a Sql array type. <br> + * 3. For a Map, the component type is needed to create a Sql map type. * * @param type * @return Calcite RelDataType @@ -291,13 +292,19 @@ public class CalciteUtils { return typeFactory.createJavaType(Date.class); } else if (type instanceof Class && ByteString.class.isAssignableFrom((Class<?>) type)) { return typeFactory.createJavaType(byte[].class); - } else if (type instanceof ParameterizedType - && java.util.List.class.isAssignableFrom( - (Class<?>) ((ParameterizedType) type).getRawType())) { + } else if (type instanceof ParameterizedType) { ParameterizedType parameterizedType = (ParameterizedType) type; - Class<?> genericType = (Class<?>) parameterizedType.getActualTypeArguments()[0]; - RelDataType collectionElementType = typeFactory.createJavaType(genericType); - return typeFactory.createArrayType(collectionElementType, UNLIMITED_ARRAY_SIZE); + if (java.util.List.class.isAssignableFrom((Class<?>) parameterizedType.getRawType())) { + Class<?> genericType = (Class<?>) parameterizedType.getActualTypeArguments()[0]; + RelDataType collectionElementType = typeFactory.createJavaType(genericType); + return typeFactory.createArrayType(collectionElementType, UNLIMITED_ARRAY_SIZE); + } else if (java.util.Map.class.isAssignableFrom((Class<?>) parameterizedType.getRawType())) { + Class<?> genericKeyType = (Class<?>) parameterizedType.getActualTypeArguments()[0]; + Class<?> genericValueType = (Class<?>) parameterizedType.getActualTypeArguments()[1]; + RelDataType mapElementKeyType = typeFactory.createJavaType(genericKeyType); + RelDataType mapElementValueType = typeFactory.createJavaType(genericValueType); + return typeFactory.createMapType(mapElementKeyType, mapElementValueType); + } } return typeFactory.createJavaType((Class) type); } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslUdfUdafTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslUdfUdafTest.java index b563ea1..9f4c9a2 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslUdfUdafTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslUdfUdafTest.java @@ -27,7 +27,11 @@ import java.sql.Time; import java.sql.Timestamp; import java.time.LocalDate; import java.time.LocalTime; +import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.stream.IntStream; import org.apache.beam.sdk.extensions.sql.impl.BeamCalciteTable; @@ -137,6 +141,56 @@ public class BeamSqlDslUdfUdafTest extends BeamSqlDslBase { pipeline.run().waitUntilFinish(); } + /** GROUP-BY with UDAF that returns Map. */ + @Test + public void testUdafWithMapOutput() throws Exception { + Schema resultType = + Schema.builder() + .addInt32Field("f_int2") + .addMapField("squareAndAccumulateInMap", FieldType.STRING, FieldType.INT32) + .build(); + + Map<String, Integer> resultMap = new HashMap<String, Integer>(); + resultMap.put("squareOf-1", 1); + resultMap.put("squareOf-2", 4); + resultMap.put("squareOf-3", 9); + resultMap.put("squareOf-4", 16); + Row row = Row.withSchema(resultType).addValues(0, resultMap).build(); + + String sql = + "SELECT f_int2,squareAndAccumulateInMap(f_int) AS `squareAndAccumulateInMap` FROM PCOLLECTION GROUP BY f_int2"; + PCollection<Row> result = + boundedInput1.apply( + "testUdafWithMapOutput", + SqlTransform.query(sql) + .registerUdaf("squareAndAccumulateInMap", new SquareAndAccumulateInMap())); + PAssert.that(result).containsInAnyOrder(row); + + pipeline.run().waitUntilFinish(); + } + + /** GROUP-BY with UDAF that returns List. */ + @Test + public void testUdafWithListOutput() throws Exception { + Schema resultType = + Schema.builder() + .addInt32Field("f_int2") + .addArrayField("squareAndAccumulateInList", FieldType.INT32) + .build(); + Row row = Row.withSchema(resultType).addValue(0).addArray(Arrays.asList(1, 4, 9, 16)).build(); + + String sql = + "SELECT f_int2,squareAndAccumulateInList(f_int) AS `squareAndAccumulateInList` FROM PCOLLECTION GROUP BY f_int2"; + PCollection<Row> result = + boundedInput1.apply( + "testUdafWithListOutput", + SqlTransform.query(sql) + .registerUdaf("squareAndAccumulateInList", new SquareAndAccumulateInList())); + PAssert.that(result).containsInAnyOrder(row); + + pipeline.run().waitUntilFinish(); + } + @Test public void testUdfWithListOutput() throws Exception { Schema resultType = Schema.builder().addArrayField("array_field", FieldType.INT64).build(); @@ -458,4 +512,64 @@ public class BeamSqlDslUdfUdafTest extends BeamSqlDslBase { return BeamCalciteTable.of(new TestBoundedTable(schema).addRows(values)); } } + + /** UDAF(CombineFn) for test, which squares each input, tags it and returns them all in a Map. */ + public static class SquareAndAccumulateInMap + extends CombineFn<Integer, Map<String, Integer>, Map<String, Integer>> { + @Override + public Map<String, Integer> createAccumulator() { + return new HashMap<String, Integer>(); + } + + @Override + public Map<String, Integer> addInput(Map<String, Integer> accumulator, Integer input) { + accumulator.put("squareOf-" + input, input * input); + return accumulator; + } + + @Override + public Map<String, Integer> mergeAccumulators(Iterable<Map<String, Integer>> accumulators) { + Map<String, Integer> merged = createAccumulator(); + for (Map<String, Integer> accumulator : accumulators) { + merged.putAll(accumulator); + } + return merged; + } + + @Override + public Map<String, Integer> extractOutput(Map<String, Integer> accumulator) { + return accumulator; + } + } + + /** UDAF(CombineFn) for test, which squares each input and returns them all in a List. */ + public static class SquareAndAccumulateInList + extends CombineFn<Integer, List<Integer>, List<Integer>> { + + @Override + public List<Integer> createAccumulator() { + return new ArrayList<Integer>(); + } + + @Override + public List<Integer> addInput(List<Integer> accumulator, Integer input) { + accumulator.add(input * input); + return accumulator; + } + + @Override + public List<Integer> mergeAccumulators(Iterable<List<Integer>> accumulators) { + List<Integer> merged = createAccumulator(); + for (List<Integer> accumulator : accumulators) { + merged.addAll(accumulator); + } + return merged; + } + + @Override + public List<Integer> extractOutput(List<Integer> accumulator) { + Collections.sort(accumulator); + return accumulator; + } + } }