[ https://issues.apache.org/jira/browse/BEAM-4461?focusedWorklogId=160793&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-160793 ]
ASF GitHub Bot logged work on BEAM-4461:
----------------------------------------
Author: ASF GitHub Bot
Created on: 30/Oct/18 19:20
Start Date: 30/Oct/18 19:20
Worklog Time Spent: 10m
Work Description: reuvenlax closed pull request #6883: [BEAM-4461] Resubmit PR to switch SQL over to schema transform
URL: https://github.com/apache/beam/pull/6883
This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance:
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Group.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Group.java
index 73048528065..d1d4206d3e4 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Group.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Group.java
@@ -107,6 +107,11 @@
return new ByFields<>(FieldAccessDescriptor.withFieldNames(fieldNames));
}
+ /** Same as {@link #byFieldNames(String...)}. */
+ public static <T> ByFields<T> byFieldNames(Iterable<String> fieldNames) {
+ return new ByFields<>(FieldAccessDescriptor.withFieldNames(fieldNames));
+ }
+
/**
* Returns a transform that groups all elements in the input {@link PCollection} keyed by the list
* of fields specified. The output of this transform will be a {@link KV} keyed by a {@link Row}
@@ -117,6 +122,11 @@
return new ByFields<>(FieldAccessDescriptor.withFieldIds(fieldIds));
}
+ /** Same as {@link #byFieldIds(Integer...)}. */
+ public static <T> ByFields<T> byFieldIds(Iterable<Integer> fieldIds) {
+ return new ByFields<>(FieldAccessDescriptor.withFieldIds(fieldIds));
+ }
+
/**
* Returns a transform that groups all elements in the input {@link PCollection} keyed by the
* fields specified. The output of this transform will be a {@link KV} keyed by a {@link Row}
@@ -155,6 +165,17 @@
FieldAccessDescriptor.withFieldNames(inputFieldName), fn, outputFieldName));
}
+ /** The same as {@link #aggregateField} but using field id. */
+ public <CombineInputT, AccumT, CombineOutputT> CombineFieldsGlobally<InputT> aggregateField(
+ int inputFieldId,
+ CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
+ String outputFieldName) {
+ return new CombineFieldsGlobally<>(
+ SchemaAggregateFn.<InputT>create()
+ .aggregateFields(FieldAccessDescriptor.withFieldIds(inputFieldId), fn, outputFieldName));
+ }
+
/**
* Build up an aggregation function over the input elements.
*
@@ -171,6 +192,14 @@
FieldAccessDescriptor.withFieldNames(inputFieldName), fn, outputField));
}
+ /** The same as {@link #aggregateField} but using field id. */
+ public <CombineInputT, AccumT, CombineOutputT> CombineFieldsGlobally<InputT> aggregateField(
+ int inputFieldId, CombineFn<CombineInputT, AccumT, CombineOutputT> fn, Field outputField) {
+ return new CombineFieldsGlobally<>(
+ SchemaAggregateFn.<InputT>create()
+ .aggregateFields(FieldAccessDescriptor.withFieldIds(inputFieldId), fn, outputField));
+ }
+
/**
* Build up an aggregation function over the input elements.
*
@@ -189,6 +218,16 @@
FieldAccessDescriptor.withFieldNames(inputFieldNames), fn, outputFieldName);
}
+ /** The same as {@link #aggregateFields} but with field ids. */
+ public <CombineInputT, AccumT, CombineOutputT>
+ CombineFieldsGlobally<InputT> aggregateFieldsById(
+ List<Integer> inputFieldIds,
+ CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
+ String outputFieldName) {
+ return aggregateFields(
+ FieldAccessDescriptor.withFieldIds(inputFieldIds), fn, outputFieldName);
+ }
+
/**
* Build up an aggregation function over the input elements.
*
@@ -222,6 +261,14 @@
FieldAccessDescriptor.withFieldNames(inputFieldNames), fn, outputField);
}
+ /** The same as {@link #aggregateFields} but with field ids. */
+ public <CombineInputT, AccumT, CombineOutputT>
+ CombineFieldsGlobally<InputT> aggregateFieldsById(
+ List<Integer> inputFieldIds,
+ CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
+ Field outputField) {
+ return aggregateFields(FieldAccessDescriptor.withFieldIds(inputFieldIds), fn, outputField);
+ }
+
/**
* Build up an aggregation function over the input elements.
*
@@ -292,6 +339,15 @@
FieldAccessDescriptor.withFieldNames(inputFieldName), fn, outputFieldName));
}
+ /** The same as {@link #aggregateField} but using field id. */
+ public <CombineInputT, AccumT, CombineOutputT> CombineFieldsGlobally<InputT> aggregateField(
+ int inputFieldId,
+ CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
+ String outputFieldName) {
+ return new CombineFieldsGlobally<>(
+ schemaAggregateFn.aggregateFields(
+ FieldAccessDescriptor.withFieldIds(inputFieldId), fn, outputFieldName));
+ }
+
/**
* Build up an aggregation function over the input elements.
*
@@ -307,6 +363,13 @@
FieldAccessDescriptor.withFieldNames(inputFieldName), fn, outputField));
}
+ /** The same as {@link #aggregateField} but using field id. */
+ public <CombineInputT, AccumT, CombineOutputT> CombineFieldsGlobally<InputT> aggregateField(
+ int inputFieldId, CombineFn<CombineInputT, AccumT, CombineOutputT> fn, Field outputField) {
+ return new CombineFieldsGlobally<>(
+ schemaAggregateFn.aggregateFields(
+ FieldAccessDescriptor.withFieldIds(inputFieldId), fn, outputField));
+ }
+
/**
* Build up an aggregation function over the input elements.
*
@@ -325,6 +388,15 @@
FieldAccessDescriptor.withFieldNames(inputFieldNames), fn, outputFieldName);
}
+ /** The same as {@link #aggregateFields} but with field ids. */
+ public <CombineInputT, AccumT, CombineOutputT>
+ CombineFieldsGlobally<InputT> aggregateFieldsById(
+ List<Integer> inputFieldIds,
+ CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
+ String outputFieldName) {
+ return aggregateFields(
+ FieldAccessDescriptor.withFieldIds(inputFieldIds), fn, outputFieldName);
+ }
+
/**
* Build up an aggregation function over the input elements.
*
@@ -357,6 +429,14 @@
FieldAccessDescriptor.withFieldNames(inputFieldNames), fn, outputField);
}
+ /** The same as {@link #aggregateFields} but with field ids. */
+ public <CombineInputT, AccumT, CombineOutputT>
+ CombineFieldsGlobally<InputT> aggregateFieldsById(
+ List<Integer> inputFieldIds,
+ CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
+ Field outputField) {
+ return aggregateFields(FieldAccessDescriptor.withFieldIds(inputFieldIds), fn, outputField);
+ }
+
/**
* Build up an aggregation function over the input elements.
*
@@ -428,6 +508,17 @@ Schema getKeySchema() {
FieldAccessDescriptor.withFieldNames(inputFieldName), fn, outputFieldName));
}
+ /** The same as {@link #aggregateField} but using field id. */
+ public <CombineInputT, AccumT, CombineOutputT> CombineFieldsByFields<InputT> aggregateField(
+ int inputFieldId,
+ CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
+ String outputFieldName) {
+ return new CombineFieldsByFields<>(
+ this,
+ SchemaAggregateFn.<InputT>create()
+ .aggregateFields(FieldAccessDescriptor.withFieldIds(inputFieldId), fn, outputFieldName));
+ }
+
/**
* Build up an aggregation function over the input elements.
*
@@ -445,6 +536,14 @@ Schema getKeySchema() {
FieldAccessDescriptor.withFieldNames(inputFieldName), fn, outputField));
}
+ /** The same as {@link #aggregateField} but using field id. */
+ public <CombineInputT, AccumT, CombineOutputT> CombineFieldsByFields<InputT> aggregateField(
+ int inputFieldId, CombineFn<CombineInputT, AccumT, CombineOutputT> fn, Field outputField) {
+ return new CombineFieldsByFields<>(
+ this,
+ SchemaAggregateFn.<InputT>create()
+ .aggregateFields(FieldAccessDescriptor.withFieldIds(inputFieldId), fn, outputField));
+ }
+
/**
* Build up an aggregation function over the input elements.
*
@@ -463,6 +562,15 @@ Schema getKeySchema() {
FieldAccessDescriptor.withFieldNames(inputFieldNames), fn, outputFieldName);
}
+ /** The same as {@link #aggregateFields} but with field ids. */
+ public <CombineInputT, AccumT, CombineOutputT>
+ CombineFieldsByFields<InputT> aggregateFieldsById(
+ List<Integer> inputFieldIds,
+ CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
+ String outputFieldName) {
+ return aggregateFields(
+ FieldAccessDescriptor.withFieldIds(inputFieldIds), fn, outputFieldName);
+ }
+
/**
* Build up an aggregation function over the input elements.
*
@@ -497,6 +605,14 @@ Schema getKeySchema() {
FieldAccessDescriptor.withFieldNames(inputFieldNames), fn, outputField);
}
+ /** The same as {@link #aggregateFields} but with field ids. */
+ public <CombineInputT, AccumT, CombineOutputT>
+ CombineFieldsByFields<InputT> aggregateFieldsById(
+ List<Integer> inputFieldIds,
+ CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
+ Field outputField) {
+ return aggregateFields(FieldAccessDescriptor.withFieldIds(inputFieldIds), fn, outputField);
+ }
+
/**
* Build up an aggregation function over the input elements.
*
@@ -593,6 +709,16 @@ public void process(
FieldAccessDescriptor.withFieldNames(inputFieldName), fn, outputFieldName));
}
+ /** The same as {@link #aggregateField} but using field id. */
+ public <CombineInputT, AccumT, CombineOutputT> CombineFieldsByFields<InputT> aggregateField(
+ int inputFieldId,
+ CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
+ String outputFieldName) {
+ return new CombineFieldsByFields<>(
+ byFields,
+ schemaAggregateFn.aggregateFields(
+ FieldAccessDescriptor.withFieldIds(inputFieldId), fn, outputFieldName));
+ }
+
/**
* Build up an aggregation function over the input elements.
*
@@ -609,6 +735,14 @@ public void process(
FieldAccessDescriptor.withFieldNames(inputFieldName), fn, outputField));
}
+ /** The same as {@link #aggregateField} but using field id. */
+ public <CombineInputT, AccumT, CombineOutputT> CombineFieldsByFields<InputT> aggregateField(
+ int inputFieldId, CombineFn<CombineInputT, AccumT, CombineOutputT> fn, Field outputField) {
+ return new CombineFieldsByFields<>(
+ byFields,
+ schemaAggregateFn.aggregateFields(
+ FieldAccessDescriptor.withFieldIds(inputFieldId), fn, outputField));
+ }
+
/**
* Build up an aggregation function over the input elements.
*
@@ -659,6 +793,14 @@ public void process(
FieldAccessDescriptor.withFieldNames(inputFieldNames), fn, outputField);
}
+ /** The same as {@link #aggregateFields} but with field ids. */
+ public <CombineInputT, AccumT, CombineOutputT>
+ CombineFieldsByFields<InputT> aggregateFieldsById(
+ List<Integer> inputFieldIds,
+ CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
+ Field outputField) {
+ return aggregateFields(FieldAccessDescriptor.withFieldIds(inputFieldIds), fn, outputField);
+ }
+
/**
* Build up an aggregation function over the input elements.
*
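For context, here is a minimal usage sketch of the field-id overloads added above. The input PCollection<Row> and its schema (field 0 a grouping key, field 1 a long to sum) are hypothetical, not part of this PR:

    import java.util.Arrays;
    import org.apache.beam.sdk.schemas.transforms.Group;
    import org.apache.beam.sdk.transforms.Sum;
    import org.apache.beam.sdk.values.KV;
    import org.apache.beam.sdk.values.PCollection;
    import org.apache.beam.sdk.values.Row;

    // Group rows by field id 0 and sum field id 1 into an output field "total".
    // The result pairs each distinct key Row with a Row of aggregated values.
    PCollection<KV<Row, Row>> totals =
        input.apply(
            Group.<Row>byFieldIds(Arrays.asList(0))
                .aggregateField(1, Sum.ofLongs(), "total"));

The Iterable/List overloads mirror the existing varargs versions, which lets callers like BeamAggregationRel below pass field-id lists computed at runtime.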
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamAggregationRel.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamAggregationRel.java
index 7a7e81dc363..b35fcf84261 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamAggregationRel.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamAggregationRel.java
@@ -19,25 +19,21 @@
import static com.google.common.base.Preconditions.checkArgument;
import static java.util.stream.Collectors.toList;
-import static org.apache.beam.sdk.schemas.Schema.toSchema;
import static org.apache.beam.sdk.values.PCollection.IsBounded.BOUNDED;
-import static org.apache.beam.sdk.values.Row.toRow;
import com.google.common.collect.Lists;
+import java.io.Serializable;
import java.util.List;
import javax.annotation.Nullable;
-import org.apache.beam.sdk.coders.KvCoder;
-import org.apache.beam.sdk.extensions.sql.impl.transform.MultipleAggregationsFn;
import org.apache.beam.sdk.extensions.sql.impl.transform.agg.AggregationCombineFnAdapter;
import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils;
import org.apache.beam.sdk.schemas.Schema;
-import org.apache.beam.sdk.schemas.SchemaCoder;
-import org.apache.beam.sdk.transforms.Combine;
+import org.apache.beam.sdk.schemas.Schema.Field;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
-import org.apache.beam.sdk.transforms.WithKeys;
import org.apache.beam.sdk.transforms.WithTimestamps;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.DefaultTrigger;
@@ -118,58 +114,51 @@ public RelWriter explainTerms(RelWriter pw) {
@Override
public PTransform<PCollectionList<Row>, PCollection<Row>> buildPTransform() {
- Schema inputSchema = CalciteUtils.toSchema(getInput().getRowType());
Schema outputSchema = CalciteUtils.toSchema(getRowType());
- List<AggregationCombineFnAdapter> aggregationAdapters =
+ List<FieldAggregation> aggregationAdapters =
getNamedAggCalls()
.stream()
- .map(aggCall -> AggregationCombineFnAdapter.of(aggCall, inputSchema))
+ .map(aggCall -> new FieldAggregation(aggCall.getKey(), aggCall.getValue()))
.collect(toList());
return new Transform(
- windowFn, windowFieldIndex, inputSchema, getGroupSet(), aggregationAdapters, outputSchema);
+ windowFn, windowFieldIndex, getGroupSet(), aggregationAdapters, outputSchema);
+ }
+
+ private static class FieldAggregation implements Serializable {
+ final List<Integer> inputs;
+ final CombineFn combineFn;
+ final Field outputField;
+
+ FieldAggregation(AggregateCall call, String alias) {
+ inputs = call.getArgList();
+ outputField = CalciteUtils.toField(alias, call.getType());
+ combineFn =
+ AggregationCombineFnAdapter.createCombineFn(
+ call, outputField, call.getAggregation().getName());
+ }
}
private static class Transform extends PTransform<PCollectionList<Row>, PCollection<Row>> {
private final List<Integer> keyFieldsIds;
private Schema outputSchema;
- private Schema keySchema;
- private SchemaCoder<Row> keyCoder;
private WindowFn<Row, IntervalWindow> windowFn;
private int windowFieldIndex;
- private List<AggregationCombineFnAdapter> aggregationCalls;
- private SchemaCoder<Row> aggCoder;
+ private List<FieldAggregation> fieldAggregations;
private Transform(
WindowFn<Row, IntervalWindow> windowFn,
int windowFieldIndex,
- Schema inputSchema,
ImmutableBitSet groupSet,
- List<AggregationCombineFnAdapter> aggregationCalls,
+ List<FieldAggregation> fieldAggregations,
Schema outputSchema) {
-
this.windowFn = windowFn;
this.windowFieldIndex = windowFieldIndex;
- this.aggregationCalls = aggregationCalls;
-
+ this.fieldAggregations = fieldAggregations;
this.outputSchema = outputSchema;
-
- this.keySchema =
- groupSet
- .asList()
- .stream()
- .filter(i -> i != windowFieldIndex)
- .map(inputSchema::getField)
- .collect(toSchema());
-
this.keyFieldsIds =
groupSet.asList().stream().filter(i -> i != windowFieldIndex).collect(toList());
-
- this.keyCoder = SchemaCoder.of(keySchema);
- this.aggCoder =
- SchemaCoder.of(
- aggregationCalls.stream().map(aggCall -> aggCall.field()).collect(toSchema()));
}
@Override
@@ -187,13 +176,40 @@ private Transform(
validateWindowIsSupported(windowedStream);
- PCollection<KV<Row, Row>> exCombineByStream = extractGroupingKeys(upstream, windowedStream);
-
- PCollection<KV<Row, Row>> aggregatedStream = performAggregation(exCombineByStream);
+ org.apache.beam.sdk.schemas.transforms.Group.ByFields<Row> byFields =
+ org.apache.beam.sdk.schemas.transforms.Group.byFieldIds(keyFieldsIds);
+ org.apache.beam.sdk.schemas.transforms.Group.CombineFieldsByFields<Row> combined = null;
+ for (FieldAggregation fieldAggregation : fieldAggregations) {
+ List<Integer> inputs = fieldAggregation.inputs;
+ CombineFn combineFn = fieldAggregation.combineFn;
+ if (inputs.size() > 1 || inputs.isEmpty()) {
+ // In this path we extract a Row (an empty row if inputs.isEmpty).
+ combined =
+ (combined == null)
+ ? byFields.aggregateFieldsById(inputs, combineFn, fieldAggregation.outputField)
+ : combined.aggregateFieldsById(inputs, combineFn, fieldAggregation.outputField);
+ } else {
+ // Combining over a single field, so extract just that field.
+ combined =
+ (combined == null)
+ ? byFields.aggregateField(inputs.get(0), combineFn, fieldAggregation.outputField)
+ : combined.aggregateField(inputs.get(0), combineFn, fieldAggregation.outputField);
+ }
+ }

- PCollection<Row> mergedStream = mergeRows(aggregatedStream);
+ PTransform<PCollection<Row>, PCollection<KV<Row, Row>>> combiner = combined;
+ if (combiner == null) {
+ // If no field aggregations were specified, we run a constant combiner that always returns
+ // a single empty row for each key. This is used by the SELECT DISTINCT query plan - in this
+ // case a group by is generated to determine unique keys, and a constant null combiner is
+ // used.
+ combiner = byFields.aggregate(AggregationCombineFnAdapter.createConstantCombineFn());
+ }

- return mergedStream;
+ return windowedStream
+ .apply(combiner)
+ .apply("mergeRecord", ParDo.of(mergeRecord(outputSchema, windowFieldIndex)))
+ .setRowSchema(outputSchema);
}
/** Extract timestamps from the windowFieldIndex, then window into windowFns. */
@@ -210,33 +226,6 @@ private Transform(
return windowedStream;
}
- /** Extract non-windowing group-by fields, assign them as a key. */
- private PCollection<KV<Row, Row>> extractGroupingKeys(
- PCollection<Row> upstream, PCollection<Row> windowedStream) {
- return windowedStream
- .apply(
- "exCombineBy",
- WithKeys.of(
- row -> keyFieldsIds.stream().map(row::getValue).collect(toRow(keySchema))))
- .setCoder(KvCoder.of(keyCoder, upstream.getCoder()));
- }
-
- private PCollection<KV<Row, Row>> performAggregation(
- PCollection<KV<Row, Row>> exCombineByStream) {
- return exCombineByStream
- .apply("combineBy", Combine.perKey(MultipleAggregationsFn.combineFns(aggregationCalls)))
- .setCoder(KvCoder.of(keyCoder, aggCoder));
- }
-
- /** Merge the KVs back into whole rows. */
- private PCollection<Row> mergeRows(PCollection<KV<Row, Row>> aggregatedStream) {
- PCollection<Row> mergedStream =
- aggregatedStream.apply(
- "mergeRecord", ParDo.of(mergeRecord(outputSchema, windowFieldIndex)));
- mergedStream.setRowSchema(outputSchema);
- return mergedStream;
- }
-
/**
* Performs the same check as {@link GroupByKey}, provides more context in exception.
*
@@ -261,14 +250,14 @@ private void validateWindowIsSupported(PCollection<Row> upstream) {
}
static DoFn<KV<Row, Row>, Row> mergeRecord(Schema outputSchema, int windowStartFieldIndex) {
-
return new DoFn<KV<Row, Row>, Row>() {
@ProcessElement
- public void processElement(ProcessContext c, BoundedWindow window) {
- KV<Row, Row> kvRow = c.element();
+ public void processElement(
+ @Element KV<Row, Row> kvRow, BoundedWindow window, OutputReceiver<Row> o) {
List<Object> fieldValues =
Lists.newArrayListWithCapacity(
kvRow.getKey().getValues().size() + kvRow.getValue().getValues().size());
+
fieldValues.addAll(kvRow.getKey().getValues());
fieldValues.addAll(kvRow.getValue().getValues());
@@ -276,7 +265,7 @@ public void processElement(ProcessContext c, BoundedWindow window) {
fieldValues.add(windowStartFieldIndex, ((IntervalWindow) window).start());
}
- c.output(Row.withSchema(outputSchema).addValues(fieldValues).build());
+ o.output(Row.withSchema(outputSchema).addValues(fieldValues).build());
}
};
}
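Two things stand out in the rewritten Transform above: aggregation is now expressed by chaining Group.aggregateField / aggregateFieldsById calls, and mergeRecord migrates from ProcessContext to the @Element / OutputReceiver parameter style. A minimal standalone sketch of that DoFn style, with purely illustrative pass-through logic:

    import org.apache.beam.sdk.transforms.DoFn;
    import org.apache.beam.sdk.values.KV;
    import org.apache.beam.sdk.values.Row;

    // New-style parameter injection: the element and output receiver are
    // declared as parameters instead of being read off a ProcessContext.
    class EmitValueRow extends DoFn<KV<Row, Row>, Row> {
      @ProcessElement
      public void processElement(@Element KV<Row, Row> kvRow, OutputReceiver<Row> out) {
        // kvRow replaces c.element(); out.output(...) replaces c.output(...).
        out.output(kvRow.getValue());
      }
    }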
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/MultipleAggregationsFn.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/MultipleAggregationsFn.java
deleted file mode 100644
index 0b8d8961410..00000000000
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/MultipleAggregationsFn.java
+++ /dev/null
@@ -1,175 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.sdk.extensions.sql.impl.transform;
-
-import static org.apache.beam.sdk.schemas.Schema.toSchema;
-import static org.apache.beam.sdk.values.Row.toRow;
-
-import java.io.IOException;
-import java.io.InputStream;
-import java.io.OutputStream;
-import java.math.BigDecimal;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.stream.Collectors;
-import java.util.stream.IntStream;
-import org.apache.beam.sdk.coders.BigDecimalCoder;
-import org.apache.beam.sdk.coders.CannotProvideCoderException;
-import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.coders.CoderException;
-import org.apache.beam.sdk.coders.CoderRegistry;
-import org.apache.beam.sdk.coders.CustomCoder;
-import org.apache.beam.sdk.coders.VarIntCoder;
-import org.apache.beam.sdk.extensions.sql.impl.transform.agg.AggregationCombineFnAdapter;
-import org.apache.beam.sdk.schemas.Schema;
-import org.apache.beam.sdk.transforms.Combine.CombineFn;
-import org.apache.beam.sdk.values.Row;
-
-/**
- * Wrapper for multiple aggregations.
- *
- * <p>Maintains the accumulators and output schema. Delegates the aggregation to the combiners for
- * each separate aggregation call.
- *
- * <p>Output schema is the schema corresponding to the list of all aggregation calls.
- */
-public class MultipleAggregationsFn extends CombineFn<Row, List<Object>, Row> {
- private List<AggregationCombineFnAdapter> aggCombineFns;
- private Schema outputSchema;
-
- /**
- * Returns an instance of {@link MultipleAggregationsFn}.
- *
- * @param aggCombineFns is the list of aggregation {@link CombineFn CombineFns} that perform the
- * actual aggregations.
- */
- public static MultipleAggregationsFn combineFns(List<AggregationCombineFnAdapter> aggCombineFns) {
- return new MultipleAggregationsFn(aggCombineFns);
- }
-
- private MultipleAggregationsFn(List<AggregationCombineFnAdapter> aggCombineFns) {
- this.aggCombineFns = aggCombineFns;
- outputSchema =
- this.aggCombineFns.stream().map(AggregationCombineFnAdapter::field).collect(toSchema());
- }
-
- /**
- * Accumulator for this {@link CombineFn} is a list of accumulators of the underlying delegate
- * {@link CombineFn CombineFns}.
- */
- @Override
- public List<Object> createAccumulator() {
- return aggCombineFns
- .stream()
- .map(AggregationCombineFnAdapter::createAccumulator)
- .collect(Collectors.toList());
- }
-
- /** For each delegate {@link CombineFn} we use the corresponding accumulator from the list. */
- @Override
- public List<Object> addInput(List<Object> accumulators, Row input) {
- List<Object> newAcc = new ArrayList<>();
-
- for (int idx = 0; idx < aggCombineFns.size(); ++idx) {
- AggregationCombineFnAdapter aggregator = aggCombineFns.get(idx);
- Object aggregatorAccumulator = accumulators.get(idx);
-
- Object newAccumulator = aggregator.addInput(aggregatorAccumulator, input);
- newAcc.add(newAccumulator);
- }
- return newAcc;
- }
-
- /**
- * Collect all accumulators for the corresponding {@link CombineFn} into a list, and pass the list
- * for merging to the delegate.
- */
- @Override
- public List<Object> mergeAccumulators(Iterable<List<Object>> accumulators) {
- List<Object> newAcc = new ArrayList<>();
- for (int idx = 0; idx < aggCombineFns.size(); ++idx) {
- List accs = new ArrayList<>();
- for (List<Object> accumulator : accumulators) {
- accs.add(accumulator.get(idx));
- }
- newAcc.add(aggCombineFns.get(idx).mergeAccumulators(accs));
- }
- return newAcc;
- }
-
- /**
- * Just extract all outputs from the delegate {@link CombineFn CombineFns} and assemble them into
- * a row.
- */
- @Override
- public Row extractOutput(List<Object> accumulator) {
- return IntStream.range(0, aggCombineFns.size())
- .mapToObj(idx -> aggCombineFns.get(idx).extractOutput(accumulator.get(idx)))
- .collect(toRow(outputSchema));
- }
-
- /**
- * Accumulator coder is a special {@link AggregationAccumulatorCoder coder} that encodes a list of
- * all accumulators using accumulator coders from their {@link CombineFn CombineFns}.
- */
- @Override
- public Coder<List<Object>> getAccumulatorCoder(CoderRegistry registry, Coder<Row> inputCoder)
- throws CannotProvideCoderException {
- // TODO: Doing this here is wrong.
- registry.registerCoderForClass(BigDecimal.class, BigDecimalCoder.of());
- List<Coder> aggAccuCoderList = new ArrayList<>();
-
- for (AggregationCombineFnAdapter aggCombineFn : aggCombineFns) {
- aggAccuCoderList.add(aggCombineFn.getAccumulatorCoder(registry, inputCoder));
- }
-
- return new AggregationAccumulatorCoder(aggAccuCoderList);
- }
-
- /**
- * Coder for accumulators.
- *
- * <p>Takes in a list of accumulator coders, delegates encoding/decoding to them.
- */
- static class AggregationAccumulatorCoder extends CustomCoder<List<Object>> {
- private VarIntCoder sizeCoder = VarIntCoder.of();
- private List<Coder> elementCoders;
-
- AggregationAccumulatorCoder(List<Coder> elementCoders) {
- this.elementCoders = elementCoders;
- }
-
- @Override
- public void encode(List<Object> value, OutputStream outStream) throws IOException {
- sizeCoder.encode(value.size(), outStream);
- for (int idx = 0; idx < value.size(); ++idx) {
- elementCoders.get(idx).encode(value.get(idx), outStream);
- }
- }
-
- @Override
- public List<Object> decode(InputStream inStream) throws CoderException, IOException {
- List<Object> accu = new ArrayList<>();
- int size = sizeCoder.decode(inStream);
- for (int idx = 0; idx < size; ++idx) {
- accu.add(elementCoders.get(idx).decode(inStream));
- }
- return accu;
- }
- }
-}
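The deleted class above implemented a common composed-combiner pattern: one accumulator list, where slot i is advanced by delegate aggregation i. A self-contained sketch of that idea, with sum and max over longs standing in for the delegate CombineFns (names are mine, not Beam's):

    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.List;

    public class ComposedAggregationSketch {
      public static void main(String[] args) {
        List<Long> input = Arrays.asList(3L, 1L, 4L);

        // createAccumulator(): one slot per delegate (sum, max).
        List<Long> acc = new ArrayList<>(Arrays.asList(0L, Long.MIN_VALUE));

        // addInput(): each delegate advances its own slot in lockstep.
        for (long v : input) {
          acc.set(0, acc.get(0) + v);          // sum delegate
          acc.set(1, Math.max(acc.get(1), v)); // max delegate
        }

        // extractOutput(): assemble delegate outputs into one record.
        System.out.println("sum=" + acc.get(0) + " max=" + acc.get(1)); // sum=8 max=4
      }
    }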
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/AggregationArgsAdapter.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/AggregationArgsAdapter.java
deleted file mode 100644
index 775c16c8a63..00000000000
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/AggregationArgsAdapter.java
+++ /dev/null
@@ -1,171 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.sdk.extensions.sql.impl.transform.agg;
-
-import static java.util.stream.Collectors.toList;
-import static org.apache.beam.sdk.schemas.Schema.toSchema;
-
-import java.io.Serializable;
-import java.util.List;
-import javax.annotation.Nullable;
-import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.coders.RowCoder;
-import org.apache.beam.sdk.schemas.Schema;
-import org.apache.beam.sdk.schemas.SchemaCoder;
-import org.apache.beam.sdk.transforms.Combine.CombineFn;
-import org.apache.beam.sdk.values.Row;
-
-/**
- * Utility class to extract arguments from the input Row to match the expected input of the {@link
- * CombineFn}.
- */
-class AggregationArgsAdapter {
-
- /**
- * Creates an args adapter based on the args list and input schema.
- *
- * @param argList indices of fields that are specified as arguments for an aggregation call.
- * @param inputSchema input row that will be passed into the aggregation
- */
- static ArgsAdapter of(List<Integer> argList, Schema inputSchema) {
- if (argList.size() == 0) {
- return ZeroArgsAdapter.INSTANCE;
- } else if (argList.size() == 1) {
- return new SingleArgAdapter(inputSchema, argList);
- } else {
- return new MultiArgsAdapter(inputSchema, argList);
- }
- }
-
- /**
- * If SQL aggregation call doesn't have actual arguments, we pass an empty row to it.
- *
- * <p>This is for cases like COUNT(1) which doesn't take any arguments, or COUNT(*) that is a
- * special case that returns the same result. In both of these cases we should count all Rows no
- * matter whether they have NULLs or not.
- * matter whether they have NULLs or not.
- *
- * <p>This is a special case of the MultiArgsAdapter below.
- */
- static class ZeroArgsAdapter implements ArgsAdapter {
- private static final Schema EMPTY_SCHEMA = Schema.builder().build();
- private static final Row EMPTY_ROW = Row.withSchema(EMPTY_SCHEMA).build();
- private static final Coder<Row> EMPTY_ROW_CODER = SchemaCoder.of(EMPTY_SCHEMA);
-
- static final ZeroArgsAdapter INSTANCE = new ZeroArgsAdapter();
-
- /** Extracts the value from the first field of a row. */
- @Nullable
- @Override
- public Object getArgsValues(Row input) {
- return EMPTY_ROW;
- }
-
- /** Coder for the first field of a row. */
- @Override
- public Coder getArgsValuesCoder() {
- return EMPTY_ROW_CODER;
- }
- }
-
- /**
- * If SQL aggregation call has a single argument (e.g. MAX), we extract its raw value to pass to
- * the delegate {@link CombineFn}.
- */
- static class SingleArgAdapter implements ArgsAdapter {
- Schema sourceSchema;
- List<Integer> argsIndicies;
-
- public SingleArgAdapter(Schema sourceSchema, List<Integer> argsIndicies) {
- this.sourceSchema = sourceSchema;
- this.argsIndicies = argsIndicies;
- }
-
- /**
- * Args indices contain a single element with the index of a field that SQL call specifies. Here
- * we extract the value of that field from the input row.
- */
- @Nullable
- @Override
- public Object getArgsValues(Row input) {
- return input.getValue(argsIndicies.get(0));
- }
-
- /** Coder for the field of a row used as an argument. */
- @Override
- public Coder getArgsValuesCoder() {
- int fieldIndex = argsIndicies.get(0);
- return RowCoder.coderForFieldType(sourceSchema.getField(fieldIndex).getType());
- }
- }
-
- /**
- * If SQL aggregation call has multiple arguments (e.g. COVAR_POP), we extract the fields
- * specified in the arguments, then combine them into a row, and then pass into the delegate
- * {@link CombineFn}.
- */
- static class MultiArgsAdapter implements ArgsAdapter {
- Schema sourceSchema;
- List<Integer> argsIndicies;
-
- MultiArgsAdapter(Schema sourceSchema, List<Integer> argsIndicies) {
- this.sourceSchema = sourceSchema;
- this.argsIndicies = argsIndicies;
- }
-
- /**
- * Extract the sub-row with the fields specified in the arguments. If args values contain nulls,
- * return null.
- */
- @Nullable
- @Override
- public Object getArgsValues(Row input) {
- List<Object> argsValues = argsIndicies.stream().map(input::getValue).collect(toList());
-
- if (argsValues.contains(null)) {
- return null;
- }
-
- Schema argsSchema =
- argsIndicies
- .stream()
- .map(fieldIndex -> input.getSchema().getField(fieldIndex))
- .collect(toSchema());
-
- return Row.withSchema(argsSchema).addValues(argsValues).build();
- }
-
- /** Schema coder of the sub-row specified by the fields in the arguments list. */
- @Override
- public Coder getArgsValuesCoder() {
- Schema argsSchema =
- argsIndicies
- .stream()
- .map(fieldIndex -> sourceSchema.getField(fieldIndex))
- .collect(toSchema());
-
- return SchemaCoder.of(argsSchema);
- }
- }
-
- interface ArgsAdapter extends Serializable {
- @Nullable
- Object getArgsValues(Row input);
-
- Coder getArgsValuesCoder();
- }
-}
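The deleted adapters above pulled CombineFn arguments out of the input Row by index; after this PR, the FieldAccessDescriptor-based selection inside Group does that extraction instead. A plain-Java sketch of the multi-arg rule (return null when any argument is null, per SQL semantics); the class and method names here are hypothetical:

    import java.util.Arrays;
    import java.util.List;
    import java.util.stream.Collectors;

    public class ArgExtractionSketch {
      // Pick the values at the given indices out of a flat record; a null
      // argument nullifies the whole tuple, as MultiArgsAdapter did.
      static List<Object> extractArgs(List<Object> row, List<Integer> argIndices) {
        List<Object> args = argIndices.stream().map(row::get).collect(Collectors.toList());
        return args.contains(null) ? null : args;
      }

      public static void main(String[] args) {
        List<Object> row = Arrays.asList("key", 7L, null);
        System.out.println(extractArgs(row, Arrays.asList(1)));    // [7]
        System.out.println(extractArgs(row, Arrays.asList(1, 2))); // null
      }
    }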
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/AggregationCombineFnAdapter.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/AggregationCombineFnAdapter.java
index a58a4680b56..7737a1c2605 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/AggregationCombineFnAdapter.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/AggregationCombineFnAdapter.java
@@ -17,77 +17,146 @@
*/
package org.apache.beam.sdk.extensions.sql.impl.transform.agg;
+import javax.annotation.Nullable;
import org.apache.beam.sdk.coders.CannotProvideCoderException;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderRegistry;
import org.apache.beam.sdk.extensions.sql.impl.UdafImpl;
import org.apache.beam.sdk.extensions.sql.impl.transform.BeamBuiltinAggregations;
-import org.apache.beam.sdk.extensions.sql.impl.transform.agg.AggregationArgsAdapter.ArgsAdapter;
-import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils;
import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.SchemaCoder;
import org.apache.beam.sdk.transforms.Combine.CombineFn;
import org.apache.beam.sdk.values.Row;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.sql.validate.SqlUserDefinedAggFunction;
-import org.apache.calcite.util.Pair;
-/**
- * Wrapper {@link CombineFn} for aggregation function call.
- *
- * <p>Delegates to the actual aggregation {@link CombineFn}, either built-in, or UDAF.
- *
- * <p>Actual aggregation {@link CombineFn CombineFns} expect their specific arguments, not the full
- * input row. This class uses {@link ArgsAdapter arg adapters} to extract and map the call arguments
- * to the {@link CombineFn CombineFn's} inputs.
- */
-public class AggregationCombineFnAdapter extends CombineFn<Row, Object, Object> {
+/** Wrapper {@link CombineFn}s for aggregation function calls. */
+public class AggregationCombineFnAdapter<T> {
+ private abstract static class WrappedCombinerBase<T> extends CombineFn<T, Object, Object> {
+ CombineFn<T, Object, Object> combineFn;
+
+ WrappedCombinerBase(CombineFn<T, Object, Object> combineFn) {
+ this.combineFn = combineFn;
+ }
+
+ @Override
+ public Object createAccumulator() {
+ return combineFn.createAccumulator();
+ }
+
+ @Override
+ public Object addInput(Object accumulator, T input) {
+ T processedInput = getInput(input);
+ return (processedInput == null)
+ ? accumulator
+ : combineFn.addInput(accumulator, processedInput);
+ }
- // Field for a function call
- private Schema.Field field;
+ @Override
+ public Object mergeAccumulators(Iterable<Object> accumulators) {
+ return combineFn.mergeAccumulators(accumulators);
+ }
- // Actual aggregation CombineFn
- private CombineFn combineFn;
+ @Override
+ public Object extractOutput(Object accumulator) {
+ return combineFn.extractOutput(accumulator);
+ }
- // Adapter to convert input Row to CombineFn's arguments
- private ArgsAdapter argsAdapter;
+ @Nullable
+ abstract T getInput(T input);
- /** {@link Schema.Field} with this function call. */
- public Schema.Field field() {
- return field;
+ @Override
+ public Coder<Object> getAccumulatorCoder(CoderRegistry registry, Coder<T> inputCoder)
+ throws CannotProvideCoderException {
+ return combineFn.getAccumulatorCoder(registry, inputCoder);
+ }
}
- private AggregationCombineFnAdapter(
- Schema.Field field, CombineFn combineFn, ArgsAdapter argsAdapter) {
- this.field = field;
- this.combineFn = combineFn;
- this.argsAdapter = argsAdapter;
+ private static class MultiInputCombiner extends WrappedCombinerBase<Row> {
+ MultiInputCombiner(CombineFn<Row, Object, Object> combineFn) {
+ super(combineFn);
+ }
+
+ @Override
+ Row getInput(Row input) {
+ for (Object o : input.getValues()) {
+ if (o == null) {
+ return null;
+ }
+ }
+ return input;
+ }
}
- /**
- * Creates an instance of {@link AggregationCombineFnAdapter}.
- *
- * @param callWithAlias Calcite's output, represents a function call paired with its field alias
- */
- public static AggregationCombineFnAdapter of(
- Pair<AggregateCall, String> callWithAlias, Schema inputSchema) {
- AggregateCall call = callWithAlias.getKey();
- Schema.Field field = CalciteUtils.toField(callWithAlias.getValue(), call.getType());
- String functionName = call.getAggregation().getName();
-
- return new AggregationCombineFnAdapter(
- field,
- createCombineFn(call, field, functionName),
- AggregationArgsAdapter.of(call.getArgList(), inputSchema));
+ private static class SingleInputCombiner extends WrappedCombinerBase<Object> {
+ SingleInputCombiner(CombineFn<Object, Object, Object> combineFn) {
+ super(combineFn);
+ }
+
+ @Override
+ Object getInput(Object input) {
+ return input;
+ }
+ }
+
+ private static class ConstantEmpty extends CombineFn<Row, Row, Row> {
+ private static final Schema EMPTY_SCHEMA = Schema.builder().build();
+ private static final Row EMPTY_ROW = Row.withSchema(EMPTY_SCHEMA).build();
+
+ public static final ConstantEmpty INSTANCE = new ConstantEmpty();
+
+ @Override
+ public Row createAccumulator() {
+ return EMPTY_ROW;
+ }
+
+ @Override
+ public Row addInput(Row accumulator, Row input) {
+ return EMPTY_ROW;
+ }
+
+ @Override
+ public Row mergeAccumulators(Iterable<Row> accumulators) {
+ return EMPTY_ROW;
+ }
+
+ @Override
+ public Row extractOutput(Row accumulator) {
+ return EMPTY_ROW;
+ }
+
+ @Override
+ public Coder<Row> getAccumulatorCoder(CoderRegistry registry, Coder<Row> inputCoder)
+ throws CannotProvideCoderException {
+ return SchemaCoder.of(EMPTY_SCHEMA);
+ }
+
+ @Override
+ public Coder<Row> getDefaultOutputCoder(CoderRegistry registry, Coder<Row> inputCoder) {
+ return SchemaCoder.of(EMPTY_SCHEMA);
+ }
}
/** Creates either a UDAF or a built-in {@link CombineFn}. */
- private static CombineFn<?, ?, ?> createCombineFn(
+ public static CombineFn<?, ?, ?> createCombineFn(
AggregateCall call, Schema.Field field, String functionName) {
+ CombineFn combineFn;
if (call.getAggregation() instanceof SqlUserDefinedAggFunction) {
- return getUdafCombineFn(call);
+ combineFn = getUdafCombineFn(call);
+ } else {
+ combineFn = BeamBuiltinAggregations.create(functionName, field.getType().getTypeName());
+ }
+ if (call.getArgList().isEmpty()) {
+ return new SingleInputCombiner(combineFn);
+ } else if (call.getArgList().size() == 1) {
+ return new SingleInputCombiner(combineFn);
+ } else {
+ return new MultiInputCombiner(combineFn);
}
+ }
- return BeamBuiltinAggregations.create(functionName, field.getType().getTypeName());
+ public static CombineFn<Row, ?, Row> createConstantCombineFn() {
+ return ConstantEmpty.INSTANCE;
}
private static CombineFn<?, ?, ?> getUdafCombineFn(AggregateCall call) {
@@ -98,44 +167,4 @@ public static AggregationCombineFnAdapter of(
throw new IllegalStateException(e);
}
}
-
- @Override
- public Object createAccumulator() {
- return combineFn.createAccumulator();
- }
-
- /**
- * Calls the args adapter to extract the fields from the input row and pass them into the actual
- * {@link CombineFn}. E.g. input of a MAX(f) is not a full row, but just a number.
- *
- * <p>If argument is null, skip it and return the original accumulator. This is what SQL
- * aggregations are supposed to do.
- */
- @Override
- public Object addInput(Object accumulator, Row input) {
- Object argsValues = argsAdapter.getArgsValues(input);
- return (argsValues == null) ? accumulator : combineFn.addInput(accumulator, argsValues);
- }
-
- @Override
- public Object mergeAccumulators(Iterable<Object> accumulators) {
- return combineFn.mergeAccumulators(accumulators);
- }
-
- @Override
- public Object extractOutput(Object accumulator) {
- return combineFn.extractOutput(accumulator);
- }
-
- /**
- * {@link CombineFn#getAccumulatorCoder} is supposed to use input {@link Coder coder} to infer the
- * {@link Coder coder} for the accumulator. Here we call the args adapter to get the input coder
- * for the delegate {@link CombineFn}.
- */
- @Override
- public Coder<Object> getAccumulatorCoder(CoderRegistry registry, Coder<Row> inputCoder)
- throws CannotProvideCoderException {
-
- return combineFn.getAccumulatorCoder(registry, argsAdapter.getArgsValuesCoder());
- }
}
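The wrappers above implement the SQL aggregate contract that NULL arguments are skipped rather than fed to the delegate: addInput returns the accumulator unchanged when getInput() yields null. A self-contained illustration of that contract, with a plain sum standing in for a delegate CombineFn:

    import java.util.Arrays;
    import java.util.List;

    public class NullSkippingSum {
      // SQL-style: SUM over (1, NULL, 2) is 3, because NULL inputs are skipped.
      static long sumSkippingNulls(List<Long> values) {
        long acc = 0; // createAccumulator()
        for (Long v : values) {
          if (v == null) {
            continue; // addInput() returns the accumulator unchanged
          }
          acc += v; // delegate CombineFn.addInput()
        }
        return acc;
      }

      public static void main(String[] args) {
        System.out.println(sumSkippingNulls(Arrays.asList(1L, null, 2L))); // prints 3
      }
    }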
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
Issue Time Tracking
-------------------
Worklog Id: (was: 160793)
Time Spent: 16h 20m (was: 16h 10m)
> Create a library of useful transforms that use schemas
> ------------------------------------------------------
>
> Key: BEAM-4461
> URL: https://issues.apache.org/jira/browse/BEAM-4461
> Project: Beam
> Issue Type: Sub-task
> Components: sdk-java-core
> Reporter: Reuven Lax
> Assignee: Reuven Lax
> Priority: Major
> Time Spent: 16h 20m
> Remaining Estimate: 0h
>
> e.g. JoinBy(fields). Project, Filter, etc.
--
This message was sent by Atlassian JIRA
(v7.6.3#76005)