Repository: samza Updated Branches: refs/heads/master d7a071b34 -> aac6368a2
SAMZA-1651: Samza-sql - Implement GROUP BY SQL operator Author: Aditya Toomula <atoom...@linkedin.com> Reviewers: Srini P<spun...@linkedin.com> Closes #478 from atoomula/groupby1 Project: http://git-wip-us.apache.org/repos/asf/samza/repo Commit: http://git-wip-us.apache.org/repos/asf/samza/commit/aac6368a Tree: http://git-wip-us.apache.org/repos/asf/samza/tree/aac6368a Diff: http://git-wip-us.apache.org/repos/asf/samza/diff/aac6368a Branch: refs/heads/master Commit: aac6368a2409764a98109109547a5af62860ff0f Parents: d7a071b Author: Aditya Toomula <atoom...@linkedin.com> Authored: Tue Apr 17 16:02:49 2018 -0700 Committer: Jagadish <jvenkatra...@linkedin.com> Committed: Tue Apr 17 16:02:49 2018 -0700 ---------------------------------------------------------------------- .../sql/runner/SamzaSqlApplicationConfig.java | 13 +++ .../translator/LogicalAggregateTranslator.java | 94 ++++++++++++++++++++ .../samza/sql/translator/QueryTranslator.java | 17 +++- .../apache/samza/sql/TestQueryTranslator.java | 43 +++++++++ .../samza/sql/avro/schemas/PageViewCount.avsc | 45 ++++++++++ .../samza/sql/avro/schemas/PageViewCount.java | 56 ++++++++++++ .../samza/sql/system/TestAvroSystemFactory.java | 51 ++++++++--- .../samza/sql/testutil/SamzaSqlTestConfig.java | 12 +++ .../test/samzasql/TestSamzaSqlEndToEnd.java | 64 +++++++++++-- 9 files changed, 373 insertions(+), 22 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/samza/blob/aac6368a/samza-sql/src/main/java/org/apache/samza/sql/runner/SamzaSqlApplicationConfig.java ---------------------------------------------------------------------- diff --git a/samza-sql/src/main/java/org/apache/samza/sql/runner/SamzaSqlApplicationConfig.java b/samza-sql/src/main/java/org/apache/samza/sql/runner/SamzaSqlApplicationConfig.java index aeb7f35..bcefae2 100644 --- a/samza-sql/src/main/java/org/apache/samza/sql/runner/SamzaSqlApplicationConfig.java +++ 
b/samza-sql/src/main/java/org/apache/samza/sql/runner/SamzaSqlApplicationConfig.java @@ -77,6 +77,11 @@ public class SamzaSqlApplicationConfig { public static final String CFG_UDF_RESOLVER = "samza.sql.udfResolver"; public static final String CFG_FMT_UDF_RESOLVER_DOMAIN = "samza.sql.udfResolver.%s."; + + public static final String CFG_GROUPBY_WINDOW_DURATION_MS = "samza.sql.groupby.window.ms"; + + private static final long DEFAULT_GROUPBY_WINDOW_DURATION_MS = 300000; // default groupby window duration is 5 mins. + private final Map<String, RelSchemaProvider> relSchemaProvidersBySource; private final Map<String, SamzaRelConverter> samzaRelConvertersBySource; @@ -92,6 +97,8 @@ public class SamzaSqlApplicationConfig { private final List<QueryInfo> queryInfo; + private final long windowDurationMs; + public SamzaSqlApplicationConfig(Config staticConfig) { sql = fetchSqlFromConfig(staticConfig); @@ -126,6 +133,8 @@ public class SamzaSqlApplicationConfig { x -> initializePlugin("SamzaRelConverter", x.getSamzaRelConverterName(), staticConfig, CFG_FMT_SAMZA_REL_CONVERTER_DOMAIN, (o, c) -> ((SamzaRelConverterFactory) o).create(x.getSystemStream(), relSchemaProvidersBySource.get(x.getSource()), c)))); + + windowDurationMs = staticConfig.getLong(CFG_GROUPBY_WINDOW_DURATION_MS, DEFAULT_GROUPBY_WINDOW_DURATION_MS); } private static <T> T initializePlugin(String pluginName, String plugin, Config staticConfig, @@ -244,4 +253,8 @@ public class SamzaSqlApplicationConfig { public SourceResolver getSourceResolver() { return sourceResolver; } + + public long getWindowDurationMs() { + return windowDurationMs; + } } http://git-wip-us.apache.org/repos/asf/samza/blob/aac6368a/samza-sql/src/main/java/org/apache/samza/sql/translator/LogicalAggregateTranslator.java ---------------------------------------------------------------------- diff --git a/samza-sql/src/main/java/org/apache/samza/sql/translator/LogicalAggregateTranslator.java 
b/samza-sql/src/main/java/org/apache/samza/sql/translator/LogicalAggregateTranslator.java new file mode 100644 index 0000000..96179d6 --- /dev/null +++ b/samza-sql/src/main/java/org/apache/samza/sql/translator/LogicalAggregateTranslator.java @@ -0,0 +1,94 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. 
+*/ + + package org.apache.samza.sql.translator; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Supplier; +import org.apache.calcite.rel.logical.LogicalAggregate; +import org.apache.calcite.sql.SqlKind; +import org.apache.samza.SamzaException; +import org.apache.samza.operators.MessageStream; +import org.apache.samza.operators.functions.FoldLeftFunction; +import org.apache.samza.operators.windows.AccumulationMode; +import org.apache.samza.operators.windows.Windows; +import org.apache.samza.serializers.LongSerde; +import org.apache.samza.sql.data.SamzaSqlRelMessage; +import org.apache.samza.sql.serializers.SamzaSqlRelMessageSerdeFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + + +/** + * Translator to translate the LogicalAggregate node in the relational graph to the corresponding StreamGraph + * implementation. Only a single COUNT aggregate, computed over a keyed tumbling window, is currently supported. + */ +class LogicalAggregateTranslator { + + private static final Logger log = LoggerFactory.getLogger(LogicalAggregateTranslator.class); + // Used to give this translator's window operator a unique name ("tumblingWindow_" + windowId). + private final int windowId; + + LogicalAggregateTranslator(int windowId) { + this.windowId = windowId; + } + + void translate(final LogicalAggregate aggregate, final TranslatorContext context) { + validateAggregateFunctions(aggregate); + + MessageStream<SamzaSqlRelMessage> inputStream = context.getMessageStream(aggregate.getInput().getId()); + + // At this point, the assumption is that only count function is supported.
+ Supplier<Long> initialValue = () -> (long) 0; + FoldLeftFunction<SamzaSqlRelMessage, Long> foldCountFn = (m, c) -> c + 1; + + MessageStream<SamzaSqlRelMessage> outputStream = + inputStream + .window(Windows.keyedTumblingWindow(m -> m, + Duration.ofMillis(context.getExecutionContext().getSamzaSqlApplicationConfig().getWindowDurationMs()), + initialValue, + foldCountFn, + new SamzaSqlRelMessageSerdeFactory.SamzaSqlRelMessageSerde(), + new LongSerde()) + .setAccumulationMode(AccumulationMode.DISCARDING), "tumblingWindow_" + windowId) + .map(windowPane -> { + // The window key is the input message itself (m -> m). Copy its field lists instead of appending + // the aggregate column to the key record's own lists, which would modify the key in place. + List<String> fieldNames = + new ArrayList<>(windowPane.getKey().getKey().getSamzaSqlRelRecord().getFieldNames()); + List<Object> fieldValues = + new ArrayList<>(windowPane.getKey().getKey().getSamzaSqlRelRecord().getFieldValues()); + fieldNames.add(aggregate.getAggCallList().get(0).getName()); + fieldValues.add(windowPane.getMessage()); + return new SamzaSqlRelMessage(fieldNames, fieldValues); + }); + context.registerMessageStream(aggregate.getId(), outputStream); + } + + void validateAggregateFunctions(final LogicalAggregate aggregate) { + if (aggregate.getAggCallList().size() != 1) { + String errMsg = "Windowing is supported ONLY with one aggregate function but the number of given functions are " + + aggregate.getAggCallList().size(); + log.error(errMsg); + throw new SamzaException(errMsg); + } + + if (aggregate.getAggCallList().get(0).getAggregation().getKind() != SqlKind.COUNT) { + String errMsg = "Windowing is supported ONLY with COUNT aggregate function"; + log.error(errMsg); + throw new SamzaException(errMsg); + } + } +} http://git-wip-us.apache.org/repos/asf/samza/blob/aac6368a/samza-sql/src/main/java/org/apache/samza/sql/translator/QueryTranslator.java ---------------------------------------------------------------------- diff --git 
a/samza-sql/src/main/java/org/apache/samza/sql/translator/QueryTranslator.java b/samza-sql/src/main/java/org/apache/samza/sql/translator/QueryTranslator.java index b853537..c8d55e8 100644 --- 
a/samza-sql/src/main/java/org/apache/samza/sql/translator/QueryTranslator.java +++ b/samza-sql/src/main/java/org/apache/samza/sql/translator/QueryTranslator.java @@ -23,6 +23,7 @@ import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.RelRoot; import org.apache.calcite.rel.RelShuttleImpl; import org.apache.calcite.rel.core.TableScan; +import org.apache.calcite.rel.logical.LogicalAggregate; import org.apache.calcite.rel.logical.LogicalFilter; import org.apache.calcite.rel.logical.LogicalJoin; import org.apache.calcite.rel.logical.LogicalProject; @@ -64,9 +65,11 @@ public class QueryTranslator { final RelRoot relRoot = planner.plan(queryInfo.getSelectQuery()); final TranslatorContext context = new TranslatorContext(streamGraph, relRoot, executionContext); final RelNode node = relRoot.project(); - final int[] joinId = new int[1]; node.accept(new RelShuttleImpl() { + int windowId = 0; + int joinId = 0; + @Override public RelNode visit(TableScan scan) { RelNode node = super.visit(scan); @@ -91,9 +94,17 @@ public class QueryTranslator { @Override public RelNode visit(LogicalJoin join) { RelNode node = super.visit(join); - joinId[0]++; + joinId++; SourceResolver sourceResolver = context.getExecutionContext().getSamzaSqlApplicationConfig().getSourceResolver(); - new JoinTranslator(joinId[0], sourceResolver).translate(join, context); + new JoinTranslator(joinId, sourceResolver).translate(join, context); + return node; + } + + @Override + public RelNode visit(LogicalAggregate aggregate) { + RelNode node = super.visit(aggregate); + windowId++; + new LogicalAggregateTranslator(windowId).translate(aggregate, context); return node; } }); http://git-wip-us.apache.org/repos/asf/samza/blob/aac6368a/samza-sql/src/test/java/org/apache/samza/sql/TestQueryTranslator.java ---------------------------------------------------------------------- diff --git a/samza-sql/src/test/java/org/apache/samza/sql/TestQueryTranslator.java 
b/samza-sql/src/test/java/org/apache/samza/sql/TestQueryTranslator.java index 5309838..3365923 100644 --- a/samza-sql/src/test/java/org/apache/samza/sql/TestQueryTranslator.java +++ b/samza-sql/src/test/java/org/apache/samza/sql/TestQueryTranslator.java @@ -19,12 +19,14 @@ package org.apache.samza.sql; +import java.util.Collection; import java.util.HashMap; import java.util.Map; import org.apache.samza.SamzaException; import org.apache.samza.config.Config; import org.apache.samza.config.MapConfig; import org.apache.samza.operators.StreamGraphImpl; +import org.apache.samza.operators.spec.OperatorSpec; import org.apache.samza.runtime.LocalApplicationRunner; import org.apache.samza.sql.impl.ConfigBasedSourceResolverFactory; import org.apache.samza.sql.runner.SamzaSqlApplicationConfig; @@ -347,6 +349,7 @@ public class TestQueryTranslator { StreamGraphImpl streamGraph = new StreamGraphImpl(new LocalApplicationRunner(samzaConfig), samzaConfig); translator.translate(queryInfo, streamGraph); } + @Test public void testTranslateStreamTableInnerJoin() { Map<String, String> config = SamzaSqlTestConfig.fetchStaticConfigsWithFactories(configs, 10); @@ -464,4 +467,44 @@ public class TestQueryTranslator { Assert.assertEquals("sql-job-1-partition_by-stream_1", streamGraph.getInputOperators().keySet().stream().skip(2).findFirst().get().getPhysicalName()); } + + @Test + public void testTranslateGroupBy() { + Map<String, String> config = SamzaSqlTestConfig.fetchStaticConfigsWithFactories(configs, 10); + String sql = + "Insert into testavro.pageViewCountTopic" + + " select 'SampleJob' as jobName, pv.pageKey, count(*) as `count`" + + " from testavro.PAGEVIEW as pv" + + " where pv.pageKey = 'job' or pv.pageKey = 'inbox'" + + " group by (pv.pageKey)"; + config.put(SamzaSqlApplicationConfig.CFG_SQL_STMT, sql); + Config samzaConfig = SamzaSqlApplicationRunner.computeSamzaConfigs(true, new MapConfig(config)); + SamzaSqlApplicationConfig samzaSqlApplicationConfig = new 
SamzaSqlApplicationConfig(new MapConfig(config)); + QueryTranslator translator = new QueryTranslator(samzaSqlApplicationConfig); + SamzaSqlQueryParser.QueryInfo queryInfo = samzaSqlApplicationConfig.getQueryInfo().get(0); + StreamGraphImpl streamGraph = new StreamGraphImpl(new LocalApplicationRunner(samzaConfig), samzaConfig); + translator.translate(queryInfo, streamGraph); + + Assert.assertEquals(1, streamGraph.getInputOperators().size()); + Assert.assertEquals(1, streamGraph.getOutputStreams().size()); + Assert.assertTrue(streamGraph.hasWindowOrJoins()); + Collection<OperatorSpec> operatorSpecs = streamGraph.getAllOperatorSpecs(); + } + + @Test (expected = SamzaException.class) + public void testTranslateGroupByWithSumAggregator() { + Map<String, String> config = SamzaSqlTestConfig.fetchStaticConfigsWithFactories(configs, 10); + String sql = + "Insert into testavro.pageViewCountTopic" + + " select 'SampleJob' as jobName, pv.pageKey, sum(pv.profileId) as `sum`" + + " from testavro.PAGEVIEW as pv" + " where pv.pageKey = 'job' or pv.pageKey = 'inbox'" + + " group by (pv.pageKey)"; + config.put(SamzaSqlApplicationConfig.CFG_SQL_STMT, sql); + Config samzaConfig = SamzaSqlApplicationRunner.computeSamzaConfigs(true, new MapConfig(config)); + SamzaSqlApplicationConfig samzaSqlApplicationConfig = new SamzaSqlApplicationConfig(new MapConfig(config)); + QueryTranslator translator = new QueryTranslator(samzaSqlApplicationConfig); + SamzaSqlQueryParser.QueryInfo queryInfo = samzaSqlApplicationConfig.getQueryInfo().get(0); + StreamGraphImpl streamGraph = new StreamGraphImpl(new LocalApplicationRunner(samzaConfig), samzaConfig); + translator.translate(queryInfo, streamGraph); + } } http://git-wip-us.apache.org/repos/asf/samza/blob/aac6368a/samza-sql/src/test/java/org/apache/samza/sql/avro/schemas/PageViewCount.avsc ---------------------------------------------------------------------- diff --git a/samza-sql/src/test/java/org/apache/samza/sql/avro/schemas/PageViewCount.avsc 
b/samza-sql/src/test/java/org/apache/samza/sql/avro/schemas/PageViewCount.avsc new file mode 100644 index 0000000..b124f6f --- /dev/null +++ b/samza-sql/src/test/java/org/apache/samza/sql/avro/schemas/PageViewCount.avsc @@ -0,0 +1,45 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. 
+*/ + +{ + "name": "PageViewCount", + "version" : 1, + "namespace": "org.apache.samza.sql.system.avro", + "type": "record", + "fields": [ + { + "name": "jobName", + "doc": "Job Name.", + "type": ["null", "string"], + "default":null + }, + { + "name": "pageKey", + "doc": "Page key.", + "type": ["null", "string"], + "default":null + }, + { + "name": "count", + "doc" : "Count of page views.", + "type": ["null", "int"], + "default":null + } + ] +} http://git-wip-us.apache.org/repos/asf/samza/blob/aac6368a/samza-sql/src/test/java/org/apache/samza/sql/avro/schemas/PageViewCount.java ---------------------------------------------------------------------- diff --git a/samza-sql/src/test/java/org/apache/samza/sql/avro/schemas/PageViewCount.java b/samza-sql/src/test/java/org/apache/samza/sql/avro/schemas/PageViewCount.java new file mode 100644 index 0000000..6b3d25d --- /dev/null +++ b/samza-sql/src/test/java/org/apache/samza/sql/avro/schemas/PageViewCount.java @@ -0,0 +1,56 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. 
+*/ + +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.samza.sql.avro.schemas; + +@SuppressWarnings("all") +public class PageViewCount extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = org.apache.avro.Schema.parse("{\"type\":\"record\",\"name\":\"PageViewCount\",\"namespace\":\"org.apache.samza.sql.system.avro\",\"fields\":[{\"name\":\"jobName\",\"type\":[\"null\",\"string\"],\"doc\":\"Job Name.\",\"default\":null},{\"name\":\"pageKey\",\"type\":[\"null\",\"string\"],\"doc\":\"Page key.\",\"default\":null},{\"name\":\"count\",\"type\":[\"null\",\"int\"],\"doc\":\"Count of page views.\",\"default\":null}]}"); + /** Job Name. */ + public java.lang.CharSequence jobName; + /** Page key. */ + public java.lang.CharSequence pageKey; + /** Count of page views. */ + public java.lang.Integer count; + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return jobName; + case 1: return pageKey; + case 2: return count; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. 
+ @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: jobName = (java.lang.CharSequence)value$; break; + case 1: pageKey = (java.lang.CharSequence)value$; break; + case 2: count = (java.lang.Integer)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } +} http://git-wip-us.apache.org/repos/asf/samza/blob/aac6368a/samza-sql/src/test/java/org/apache/samza/sql/system/TestAvroSystemFactory.java ---------------------------------------------------------------------- diff --git a/samza-sql/src/test/java/org/apache/samza/sql/system/TestAvroSystemFactory.java b/samza-sql/src/test/java/org/apache/samza/sql/system/TestAvroSystemFactory.java index 4bd9741..9a45034 100644 --- a/samza-sql/src/test/java/org/apache/samza/sql/system/TestAvroSystemFactory.java +++ b/samza-sql/src/test/java/org/apache/samza/sql/system/TestAvroSystemFactory.java @@ -47,6 +47,7 @@ import org.apache.samza.system.SystemAdmin; import org.apache.samza.system.SystemConsumer; import org.apache.samza.system.SystemFactory; import org.apache.samza.system.SystemProducer; +import org.apache.samza.system.SystemStream; import org.apache.samza.system.SystemStreamPartition; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -57,7 +58,7 @@ public class TestAvroSystemFactory implements SystemFactory { public static final String CFG_NUM_MESSAGES = "numMessages"; public static final String CFG_INCLUDE_NULL_FOREIGN_KEYS = "includeNullForeignKeys"; - public static List<OutgoingMessageEnvelope> messages = new ArrayList<>(); + public static final String CFG_SLEEP_BETWEEN_POLLS_MS = "sleepBetweenPollsMs"; private static final String[] profileNames = {"John", "Mike", "Mary", "Joe", "Brad", "Jennifer"}; private static final int[] profileZips = {94000, 94001, 94002, 94003, 94004, 94005}; @@ -65,17 +66,19 @@ public class TestAvroSystemFactory implements SystemFactory { private static final String[] 
phoneNumbers = {"000-000-0000", "111-111-1111", "222-222-2222", "333-333-3333", "444-444-4444", "555-555-5555"}; public static final String[] companies = {"MSFT", "LKND", "GOOG", "FB", "AMZN", "CSCO"}; - private static final String[] pagekeys = {"inbox", "home", "search", "pymk", "group", "job"}; + public static final String[] pageKeys = {"inbox", "home", "search", "pymk", "group", "job"}; + + public static List<OutgoingMessageEnvelope> messages = new ArrayList<>(); public static List<String> getPageKeyProfileNameJoin(int numMessages) { return IntStream.range(0, numMessages) - .mapToObj(i -> pagekeys[i % pagekeys.length] + "," + profileNames[i % profileNames.length]) + .mapToObj(i -> pageKeys[i % pageKeys.length] + "," + profileNames[i % profileNames.length]) .collect(Collectors.toList()); } public static List<String> getPageKeyProfileNameAddressJoin(int numMessages) { return IntStream.range(0, numMessages) - .mapToObj(i -> pagekeys[i % pagekeys.length] + "," + profileNames[i % profileNames.length] + "," + + .mapToObj(i -> pageKeys[i % pageKeys.length] + "," + profileNames[i % profileNames.length] + "," + profileZips[i % profileZips.length] + "," + streetNums[i % streetNums.length]) .collect(Collectors.toList()); } @@ -83,24 +86,38 @@ public class TestAvroSystemFactory implements SystemFactory { public static List<String> getPageKeyProfileNameJoinWithNullForeignKeys(int numMessages) { // All even profileId foreign keys are null return IntStream.range(0, numMessages / 2) - .mapToObj(i -> pagekeys[(i * 2 + 1) % pagekeys.length] + "," + profileNames[(i * 2 + 1) % profileNames.length]) + .mapToObj(i -> pageKeys[(i * 2 + 1) % pageKeys.length] + "," + profileNames[(i * 2 + 1) % profileNames.length]) .collect(Collectors.toList()); } public static List<String> getPageKeyProfileNameOuterJoinWithNullForeignKeys(int numMessages) { // All even profileId foreign keys are null return IntStream.range(0, numMessages) - .mapToObj(i -> pagekeys[i % pagekeys.length] + "," + ((i % 2 
== 0) ? "null" : profileNames[i % profileNames.length])) + .mapToObj(i -> pageKeys[i % pageKeys.length] + "," + ((i % 2 == 0) ? "null" : profileNames[i % profileNames.length])) .collect(Collectors.toList()); } public static List<String> getPageKeyProfileCompanyNameJoin(int numMessages) { return IntStream.range(0, numMessages) - .mapToObj(i -> pagekeys[i % pagekeys.length] + "," + profileNames[i % profileNames.length] + + .mapToObj(i -> pageKeys[i % pageKeys.length] + "," + profileNames[i % profileNames.length] + "," + companies[i % companies.length]) .collect(Collectors.toList()); } + /** + * Expected GROUP BY pageKey counts for numMessages generated page views, restricted to includePageKeys. + * Message i carries pageKeys[i % pageKeys.length], so key k occurs quotient times, plus once more when k < remainder. + */ + public static HashMap<String, Integer> getPageKeyGroupByResult(int numMessages, Set<String> includePageKeys) { + HashMap<String, Integer> pageKeyCountMap = new HashMap<>(); + int quotient = numMessages / pageKeys.length; + int remainder = numMessages % pageKeys.length; + // forEach is a terminal operation, so the puts below actually execute (a bare IntStream.map would be lazy). + IntStream.range(0, pageKeys.length) + .forEach(k -> { + if (includePageKeys.contains(pageKeys[k])) { + pageKeyCountMap.put(pageKeys[k], quotient + ((k < remainder) ?
1 : 0)); + } + }); + return pageKeyCountMap; + } + @Override public SystemConsumer getConsumer(String systemName, Config config, MetricsRegistry registry) { return new TestAvroSystemConsumer(systemName, config); @@ -120,15 +137,18 @@ public class TestAvroSystemFactory implements SystemFactory { public static final int DEFAULT_NUM_EVENTS = 10; private final int numMessages; private final boolean includeNullForeignKeys; + private final long sleepBetweenPollsMs; private final Set<SystemStreamPartition> simpleRecordMap = new HashSet<>(); private final Set<SystemStreamPartition> profileRecordMap = new HashSet<>(); private final Set<SystemStreamPartition> companyRecordMap = new HashSet<>(); private final Set<SystemStreamPartition> pageViewRecordMap = new HashSet<>(); + private final Map<SystemStreamPartition, Integer> curMessagesPerSsp = new HashMap<>(); public TestAvroSystemConsumer(String systemName, Config config) { numMessages = config.getInt(String.format("systems.%s.%s", systemName, CFG_NUM_MESSAGES), DEFAULT_NUM_EVENTS); includeNullForeignKeys = config.getBoolean(String.format("systems.%s.%s", systemName, CFG_INCLUDE_NULL_FOREIGN_KEYS), false); + sleepBetweenPollsMs = config.getLong(String.format("systems.%s.%s", systemName, CFG_SLEEP_BETWEEN_POLLS_MS), 0); } @Override @@ -153,6 +173,7 @@ public class TestAvroSystemFactory implements SystemFactory { if (systemStreamPartition.getStream().toLowerCase().contains("pageview")) { pageViewRecordMap.add(systemStreamPartition); } + curMessagesPerSsp.put(systemStreamPartition, 0); } @Override @@ -160,13 +181,19 @@ public class TestAvroSystemFactory implements SystemFactory { throws InterruptedException { Map<SystemStreamPartition, List<IncomingMessageEnvelope>> envelopeMap = new HashMap<>(); set.forEach(ssp -> { + int curMessages = curMessagesPerSsp.get(ssp); // We send num Messages and an end of stream message following that.
- List<IncomingMessageEnvelope> envelopes = IntStream.range(0, numMessages + 1) - .mapToObj(i -> i < numMessages ? new IncomingMessageEnvelope(ssp, null, "key" + i, - getData(i, ssp)) : IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp)) - .collect(Collectors.toList()); + // Deliver about a quarter of the messages per poll. Use at least one per poll so small numMessages still + // progresses, and never go past index numMessages so exactly one end-of-stream envelope is emitted. + int batchEnd = Math.min(curMessages + Math.max(1, numMessages / 4), numMessages + 1); + List<IncomingMessageEnvelope> envelopes = + IntStream.range(curMessages, batchEnd) + .mapToObj(i -> i < numMessages ? new IncomingMessageEnvelope(ssp, null, "key" + i, + getData(i, ssp)) : IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp)) + .collect(Collectors.toList()); envelopeMap.put(ssp, envelopes); + curMessagesPerSsp.put(ssp, batchEnd); }); + if (sleepBetweenPollsMs > 0) { + Thread.sleep(sleepBetweenPollsMs); + } return envelopeMap; } @@ -256,7 +283,7 @@ public class TestAvroSystemFactory implements SystemFactory { GenericRecord record = new GenericData.Record(PageView.SCHEMA$); // All even profileId foreign keys are null record.put("profileId", includeNullForeignKeys && (index % 2 == 0) ?
null : index); - record.put("pageKey", pagekeys[index % pagekeys.length]); + record.put("pageKey", pageKeys[index % pageKeys.length]); return record; } http://git-wip-us.apache.org/repos/asf/samza/blob/aac6368a/samza-sql/src/test/java/org/apache/samza/sql/testutil/SamzaSqlTestConfig.java ---------------------------------------------------------------------- diff --git a/samza-sql/src/test/java/org/apache/samza/sql/testutil/SamzaSqlTestConfig.java b/samza-sql/src/test/java/org/apache/samza/sql/testutil/SamzaSqlTestConfig.java index b8b2814..208625d 100644 --- a/samza-sql/src/test/java/org/apache/samza/sql/testutil/SamzaSqlTestConfig.java +++ b/samza-sql/src/test/java/org/apache/samza/sql/testutil/SamzaSqlTestConfig.java @@ -32,6 +32,7 @@ import org.apache.samza.sql.avro.schemas.Company; import org.apache.samza.sql.avro.schemas.ComplexRecord; import org.apache.samza.sql.avro.schemas.EnrichedPageView; import org.apache.samza.sql.avro.schemas.PageView; +import org.apache.samza.sql.avro.schemas.PageViewCount; import org.apache.samza.sql.avro.schemas.Profile; import org.apache.samza.sql.avro.schemas.SimpleRecord; import org.apache.samza.sql.fn.FlattenUdf; @@ -62,6 +63,11 @@ public class SamzaSqlTestConfig { public static Map<String, String> fetchStaticConfigsWithFactories(Map<String, String> props, int numberOfMessages, boolean includeNullForeignKeys) { + return fetchStaticConfigsWithFactories(props, numberOfMessages, includeNullForeignKeys, 0); + } + + public static Map<String, String> fetchStaticConfigsWithFactories(Map<String, String> props, int numberOfMessages, + boolean includeNullForeignKeys, long windowDurationMs) { HashMap<String, String> staticConfigs = new HashMap<>(); staticConfigs.put(JobConfig.JOB_NAME(), "sql-job"); @@ -92,6 +98,9 @@ public class SamzaSqlTestConfig { String.valueOf(numberOfMessages)); staticConfigs.put(avroSystemConfigPrefix + TestAvroSystemFactory.CFG_INCLUDE_NULL_FOREIGN_KEYS, includeNullForeignKeys ? 
"true" : "false"); + staticConfigs.put(avroSystemConfigPrefix + TestAvroSystemFactory.CFG_SLEEP_BETWEEN_POLLS_MS, + String.valueOf(windowDurationMs / 2)); + staticConfigs.put(SamzaSqlApplicationConfig.CFG_GROUPBY_WINDOW_DURATION_MS, String.valueOf(windowDurationMs)); staticConfigs.put(avroSamzaSqlConfigPrefix + SqlSystemSourceConfig.CFG_SAMZA_REL_CONVERTER, "avro"); staticConfigs.put(avroSamzaSqlConfigPrefix + SqlSystemSourceConfig.CFG_REL_SCHEMA_PROVIDER, "config"); @@ -129,6 +138,9 @@ public class SamzaSqlTestConfig { staticConfigs.put(configAvroRelSchemaProviderDomain + String.format(ConfigBasedAvroRelSchemaProviderFactory.CFG_SOURCE_SCHEMA, "testavro", "enrichedPageViewTopic"), EnrichedPageView.SCHEMA$.toString()); + staticConfigs.put(configAvroRelSchemaProviderDomain + String.format(ConfigBasedAvroRelSchemaProviderFactory.CFG_SOURCE_SCHEMA, + "testavro", "pageViewCountTopic"), PageViewCount.SCHEMA$.toString()); + staticConfigs.putAll(props); return staticConfigs; http://git-wip-us.apache.org/repos/asf/samza/blob/aac6368a/samza-test/src/test/java/org/apache/samza/test/samzasql/TestSamzaSqlEndToEnd.java ---------------------------------------------------------------------- diff --git a/samza-test/src/test/java/org/apache/samza/test/samzasql/TestSamzaSqlEndToEnd.java b/samza-test/src/test/java/org/apache/samza/test/samzasql/TestSamzaSqlEndToEnd.java index a41463e..d511a39 100644 --- a/samza-test/src/test/java/org/apache/samza/test/samzasql/TestSamzaSqlEndToEnd.java +++ b/samza-test/src/test/java/org/apache/samza/test/samzasql/TestSamzaSqlEndToEnd.java @@ -23,8 +23,10 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.stream.Collectors; import java.util.stream.IntStream; import org.apache.avro.generic.GenericRecord; @@ -41,6 +43,7 @@ import org.apache.samza.system.OutgoingMessageEnvelope; 
import org.apache.samza.test.harness.AbstractIntegrationTestHarness; import org.junit.Assert; import org.junit.Before; +import org.junit.Ignore; import org.junit.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -197,7 +200,7 @@ public class TestSamzaSqlEndToEnd extends AbstractIntegrationTestHarness { .collect(Collectors.toList()); Assert.assertEquals(numMessages, outMessages.size()); List<String> expectedOutMessages = TestAvroSystemFactory.getPageKeyProfileNameJoin(numMessages); - Assert.assertEquals(outMessages, expectedOutMessages); + Assert.assertEquals(expectedOutMessages, outMessages); } @Test @@ -265,7 +268,7 @@ public class TestSamzaSqlEndToEnd extends AbstractIntegrationTestHarness { .stream() .filter(msg -> msg.endsWith("Mike")) .collect(Collectors.toList()); - Assert.assertEquals(outMessages, expectedOutMessages); + Assert.assertEquals(expectedOutMessages, outMessages); } @Test @@ -294,7 +297,7 @@ public class TestSamzaSqlEndToEnd extends AbstractIntegrationTestHarness { // Half the foreign keys are null. 
Assert.assertEquals(numMessages / 2, outMessages.size()); List<String> expectedOutMessages = TestAvroSystemFactory.getPageKeyProfileNameJoinWithNullForeignKeys(numMessages); - Assert.assertEquals(outMessages, expectedOutMessages); + Assert.assertEquals(expectedOutMessages, outMessages); } @Test @@ -323,7 +326,7 @@ public class TestSamzaSqlEndToEnd extends AbstractIntegrationTestHarness { Assert.assertEquals(numMessages, outMessages.size()); List<String> expectedOutMessages = TestAvroSystemFactory.getPageKeyProfileNameOuterJoinWithNullForeignKeys(numMessages); - Assert.assertEquals(outMessages, expectedOutMessages); + Assert.assertEquals(expectedOutMessages, outMessages); } @Test @@ -352,7 +355,7 @@ public class TestSamzaSqlEndToEnd extends AbstractIntegrationTestHarness { Assert.assertEquals(numMessages, outMessages.size()); List<String> expectedOutMessages = TestAvroSystemFactory.getPageKeyProfileNameOuterJoinWithNullForeignKeys(numMessages); - Assert.assertEquals(outMessages, expectedOutMessages); + Assert.assertEquals(expectedOutMessages, outMessages); } @Test @@ -382,7 +385,7 @@ public class TestSamzaSqlEndToEnd extends AbstractIntegrationTestHarness { .collect(Collectors.toList()); Assert.assertEquals(numMessages, outMessages.size()); List<String> expectedOutMessages = TestAvroSystemFactory.getPageKeyProfileCompanyNameJoin(numMessages); - Assert.assertEquals(outMessages, expectedOutMessages); + Assert.assertEquals(expectedOutMessages, outMessages); } @Test @@ -413,7 +416,54 @@ public class TestSamzaSqlEndToEnd extends AbstractIntegrationTestHarness { Assert.assertEquals(TestAvroSystemFactory.companies.length, outMessages.size()); List<String> expectedOutMessages = TestAvroSystemFactory.getPageKeyProfileCompanyNameJoin(TestAvroSystemFactory.companies.length); - Assert.assertEquals(outMessages, expectedOutMessages); + Assert.assertEquals(expectedOutMessages, outMessages); } + // Disabling the test until SAMZA-1652 and SAMZA-1661 are fixed. 
+ @Ignore + @Test + public void testEndToEndGroupBy() throws Exception { + int numMessages = 200; + long windowDurationMs = 200; + + TestAvroSystemFactory.messages.clear(); + Map<String, String> staticConfigs = + SamzaSqlTestConfig.fetchStaticConfigsWithFactories(configs, numMessages, false, windowDurationMs); + staticConfigs.putAll(configs); + String sql = + "Insert into testavro.pageViewCountTopic" + + " select 'SampleJob' as jobName, pv.pageKey, count(*) as `count`" + + " from testavro.PAGEVIEW as pv" + + " where pv.pageKey = 'job' or pv.pageKey = 'inbox'" + + " group by (pv.pageKey)"; + + List<String> sqlStmts = Arrays.asList(sql); + staticConfigs.put(SamzaSqlApplicationConfig.CFG_SQL_STMTS_JSON, JsonUtil.toJson(sqlStmts)); + SamzaSqlApplicationRunner runner = new SamzaSqlApplicationRunner(true, new MapConfig(staticConfigs)); + runner.runAndWaitForFinish(); + + // Let's capture the list of windows/counts per key. forEach is used (not a lazy stream().map) so every + // output message is actually visited. + HashMap<String, List<String>> pageKeyCountListMap = new HashMap<>(); + TestAvroSystemFactory.messages.forEach(x -> { + String pageKey = ((GenericRecord) x.getMessage()).get("pageKey").toString(); + String count = ((GenericRecord) x.getMessage()).get("count").toString(); + pageKeyCountListMap.computeIfAbsent(pageKey, k -> new ArrayList<>()).add(count); + }); + + HashMap<String, Integer> pageKeyCountMap = new HashMap<>(); + pageKeyCountListMap.forEach((key, list) -> { + // Check that each key was emitted in more than one window, but in fewer windows than its number of input messages.
+ Assert.assertTrue(list.size() > 1 && list.size() < numMessages / TestAvroSystemFactory.pageKeys.length); + // Collapse the count of messages per key + pageKeyCountMap.put(key, list.stream().mapToInt(Integer::parseInt).sum()); + }); + + Set<String> pageKeys = new HashSet<>(Arrays.asList("job", "inbox")); + HashMap<String, Integer> expectedPageKeyCountMap = + TestAvroSystemFactory.getPageKeyGroupByResult(numMessages, pageKeys); + + Assert.assertEquals(expectedPageKeyCountMap, pageKeyCountMap); + } }