Add an element batching transform.
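
A minimal usage sketch (the pipeline below mirrors test_constant_batch in
the included util_test.py; the Create and Map(len) steps are illustrative
only):

    import apache_beam as beam
    from apache_beam.transforms import util

    with beam.Pipeline() as p:
      batch_sizes = (
          p
          | beam.Create(range(35))
          | util.BatchElements(min_batch_size=10, max_batch_size=10)
          | beam.Map(len))  # yields 10, 10, 10, 5 for a single bundle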
Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/d226c767
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/d226c767
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/d226c767

Branch: refs/heads/master
Commit: d226c7679b9d94a40553609f31ecbfba72559e8a
Parents: 3dc7559
Author: Robert Bradshaw <rober...@gmail.com>
Authored: Mon Oct 9 16:46:19 2017 -0700
Committer: Robert Bradshaw <rober...@gmail.com>
Committed: Fri Oct 13 17:13:41 2017 -0700

----------------------------------------------------------------------
 sdks/python/apache_beam/transforms/util.py      | 260 +++++++++++++++++++
 sdks/python/apache_beam/transforms/util_test.py | 108 ++++++++
 2 files changed, 368 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/d226c767/sdks/python/apache_beam/transforms/util.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py
index 647781f..85d4975 100644
--- a/sdks/python/apache_beam/transforms/util.py
+++ b/sdks/python/apache_beam/transforms/util.py
@@ -20,14 +20,25 @@
 
 from __future__ import absolute_import
 
+import collections
+import contextlib
+import time
+
+from apache_beam import typehints
+from apache_beam.metrics import Metrics
+from apache_beam.transforms import window
 from apache_beam.transforms.core import CombinePerKey
+from apache_beam.transforms.core import DoFn
 from apache_beam.transforms.core import Flatten
 from apache_beam.transforms.core import GroupByKey
 from apache_beam.transforms.core import Map
+from apache_beam.transforms.core import ParDo
 from apache_beam.transforms.ptransform import PTransform
 from apache_beam.transforms.ptransform import ptransform_fn
+from apache_beam.utils import windowed_value
 
 __all__ = [
+    'BatchElements',
     'CoGroupByKey',
     'Keys',
     'KvSwap',
@@ -36,6 +47,9 @@ __all__ = [
     ]
 
 
+T = typehints.TypeVariable('T')
+
+
 class CoGroupByKey(PTransform):
   """Groups results across several PCollections by key.
 
@@ -163,3 +177,249 @@ def RemoveDuplicates(pcoll):  # pylint: disable=invalid-name
           | 'ToPairs' >> Map(lambda v: (v, None))
           | 'Group' >> CombinePerKey(lambda vs: None)
           | 'RemoveDuplicates' >> Keys())
+
+
+class _BatchSizeEstimator(object):
+  """Estimates the best size for batches given historical timing.
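+
+  Maintains a history of (batch size, processing time) observations recorded
+  via record_time(), fits a linear model
+
+      time = fixed_cost + batch_size * per_element_cost
+
+  to that history, and chooses the largest batch size satisfying the
+  configured overhead and duration targets, growing by at most
+  _MAX_GROWTH_FACTOR per step.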
+  """
+
+  _MAX_DATA_POINTS = 100
+  _MAX_GROWTH_FACTOR = 2
+
+  def __init__(self,
+               min_batch_size=1,
+               max_batch_size=1000,
+               target_batch_overhead=.1,
+               target_batch_duration_secs=1,
+               clock=time.time):
+    if min_batch_size > max_batch_size:
+      raise ValueError("Minimum (%s) must not be greater than maximum (%s)" % (
+          min_batch_size, max_batch_size))
+    if target_batch_overhead and not 0 < target_batch_overhead <= 1:
+      raise ValueError("target_batch_overhead (%s) must be between 0 and 1" % (
+          target_batch_overhead))
+    if target_batch_duration_secs and target_batch_duration_secs <= 0:
+      raise ValueError("target_batch_duration_secs (%s) must be positive" % (
+          target_batch_duration_secs))
+    if max(0, target_batch_overhead, target_batch_duration_secs) == 0:
+      raise ValueError("At least one of target_batch_overhead or "
+                       "target_batch_duration_secs must be positive.")
+    self._min_batch_size = min_batch_size
+    self._max_batch_size = max_batch_size
+    self._target_batch_overhead = target_batch_overhead
+    self._target_batch_duration_secs = target_batch_duration_secs
+    self._clock = clock
+    self._data = []
+    self._ignore_next_timing = False
+    self._size_distribution = Metrics.distribution(
+        'BatchElements', 'batch_size')
+    self._time_distribution = Metrics.distribution(
+        'BatchElements', 'msec_per_batch')
+    # Beam distributions only accept integer values, so we use this to
+    # accumulate under-reported values until they add up to whole milliseconds.
+    # (Milliseconds are chosen because that's conventionally used elsewhere in
+    # profiling-style counters.)
+    self._remainder_msecs = 0
+
+  def ignore_next_timing(self):
+    """Call to indicate the next timing should be ignored.
+
+    For example, the first emit of a ParDo operation is known to be anomalous
+    due to setup that may occur.
+    """
+    self._ignore_next_timing = True
+
+  @contextlib.contextmanager
+  def record_time(self, batch_size):
+    start = self._clock()
+    yield
+    elapsed = self._clock() - start
+    elapsed_msec = 1e3 * elapsed + self._remainder_msecs
+    self._size_distribution.update(batch_size)
+    self._time_distribution.update(int(elapsed_msec))
+    self._remainder_msecs = elapsed_msec - int(elapsed_msec)
+    if self._ignore_next_timing:
+      self._ignore_next_timing = False
+    else:
+      self._data.append((batch_size, elapsed))
+      if len(self._data) >= self._MAX_DATA_POINTS:
+        self._thin_data()
+
+  def _thin_data(self):
+    sorted_data = sorted(self._data)
+    odd_one_out = [sorted_data[-1]] if len(sorted_data) % 2 == 1 else []
+    # Sort the pairs by how different they are.
+    pairs = sorted(zip(sorted_data[::2], sorted_data[1::2]),
+                   key=lambda ((x1, _1), (x2, _2)): x2 / x1)
+    # Keep the top 1/3 most different pairs, average the top 2/3 most similar.
+    threshold = 2 * len(pairs) / 3
+    self._data = (
+        list(sum(pairs[threshold:], ()))
+        + [((x1 + x2) / 2.0, (t1 + t2) / 2.0)
+           for (x1, t1), (x2, t2) in pairs[:threshold]]
+        + odd_one_out)
+
+  def next_batch_size(self):
+    if self._min_batch_size == self._max_batch_size:
+      return self._min_batch_size
+    elif len(self._data) < 1:
+      return self._min_batch_size
+    elif len(self._data) < 2:
+      # Force some variety so we have distinct batch sizes on which to do
+      # linear regression below.
+      return int(max(
+          min(self._max_batch_size,
+              self._min_batch_size * self._MAX_GROWTH_FACTOR),
+          self._min_batch_size + 1))
+
+    # Linear regression for y = a + bx, where x is batch size and y is time.
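+    # The closed-form least-squares estimates computed below are
+    #   b = sum((x - xbar) * (y - ybar)) / sum((x - xbar)**2)
+    #   a = ybar - b * xbar
+    # where a approximates the fixed per-batch cost and b the per-element
+    # cost; both feed the duration and overhead caps computed further down.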
+    xs, ys = zip(*self._data)
+    n = float(len(self._data))
+    xbar = sum(xs) / n
+    ybar = sum(ys) / n
+    b = (sum([(x - xbar) * (y - ybar) for x, y in self._data])
+         / sum([(x - xbar)**2 for x in xs]))
+    a = ybar - b * xbar
+
+    # Avoid nonsensical or division-by-zero errors below due to noise.
+    a = max(a, 1e-10)
+    b = max(b, 1e-20)
+
+    last_batch_size = self._data[-1][0]
+    cap = min(last_batch_size * self._MAX_GROWTH_FACTOR, self._max_batch_size)
+
+    if self._target_batch_duration_secs:
+      # Solution to a + b*x = self._target_batch_duration_secs.
+      cap = min(cap, (self._target_batch_duration_secs - a) / b)
+
+    if self._target_batch_overhead:
+      # Solution to a / (a + b*x) = self._target_batch_overhead.
+      cap = min(cap, (a / b) * (1 / self._target_batch_overhead - 1))
+
+    # Avoid getting stuck at min_batch_size.
+    jitter = len(self._data) % 2
+    return int(max(self._min_batch_size + jitter, cap))
+
+
+class _GlobalWindowsBatchingDoFn(DoFn):
+  def __init__(self, batch_size_estimator):
+    self._batch_size_estimator = batch_size_estimator
+
+  def start_bundle(self):
+    self._batch = []
+    self._batch_size = self._batch_size_estimator.next_batch_size()
+    # The first emit often involves non-trivial setup.
+    self._batch_size_estimator.ignore_next_timing()
+
+  def process(self, element):
+    self._batch.append(element)
+    if len(self._batch) >= self._batch_size:
+      with self._batch_size_estimator.record_time(self._batch_size):
+        yield self._batch
+      self._batch = []
+      self._batch_size = self._batch_size_estimator.next_batch_size()
+
+  def finish_bundle(self):
+    if self._batch:
+      with self._batch_size_estimator.record_time(self._batch_size):
+        yield window.GlobalWindows.windowed_value(self._batch)
+      self._batch = None
+      self._batch_size = self._batch_size_estimator.next_batch_size()
+
+
+class _WindowAwareBatchingDoFn(DoFn):
+
+  _MAX_LIVE_WINDOWS = 10
+
+  def __init__(self, batch_size_estimator):
+    self._batch_size_estimator = batch_size_estimator
+
+  def start_bundle(self):
+    self._batches = collections.defaultdict(list)
+    self._batch_size = self._batch_size_estimator.next_batch_size()
+    # The first emit often involves non-trivial setup.
+    self._batch_size_estimator.ignore_next_timing()
+
+  def process(self, element, window=DoFn.WindowParam):
+    self._batches[window].append(element)
+    if len(self._batches[window]) >= self._batch_size:
+      with self._batch_size_estimator.record_time(self._batch_size):
+        yield windowed_value.WindowedValue(
+            self._batches[window], window.max_timestamp(), (window,))
+      del self._batches[window]
+      self._batch_size = self._batch_size_estimator.next_batch_size()
+    elif len(self._batches) > self._MAX_LIVE_WINDOWS:
+      # Flush the fullest batch to bound the number of live windows tracked.
+      window, _ = sorted(
+          self._batches.items(),
+          key=lambda window_batch: len(window_batch[1]),
+          reverse=True)[0]
+      with self._batch_size_estimator.record_time(self._batch_size):
+        yield windowed_value.WindowedValue(
+            self._batches[window], window.max_timestamp(), (window,))
+      del self._batches[window]
+      self._batch_size = self._batch_size_estimator.next_batch_size()
+
+  def finish_bundle(self):
+    for window, batch in self._batches.items():
+      if batch:
+        with self._batch_size_estimator.record_time(self._batch_size):
+          yield windowed_value.WindowedValue(
+              batch, window.max_timestamp(), (window,))
+    self._batches = None
+    self._batch_size = self._batch_size_estimator.next_batch_size()
+
+
+@typehints.with_input_types(T)
+@typehints.with_output_types(typehints.List[T])
+class BatchElements(PTransform):
+  """A Transform that batches elements for amortized processing.
+
+  This transform is designed to precede operations whose processing cost
+  is of the form
+
+      time = fixed_cost + num_elements * per_element_cost
+
+  where the per-element cost is (often significantly) smaller than the fixed
+  cost and could be amortized over multiple elements.  It consumes a
+  PCollection of element type T and produces a PCollection of element type
+  List[T].
+
+  This transform attempts to find the best batch size between the minimum
+  and maximum parameters by profiling the time taken by (fused) downstream
+  operations.  For a fixed batch size, set the min and max to be equal.
+
+  Elements are batched per-window, and each batch is emitted in the window
+  corresponding to its contents.
+
+  Args:
+    min_batch_size: (optional) the smallest number of elements per batch
+    max_batch_size: (optional) the largest number of elements per batch
+    target_batch_overhead: (optional) a target for fixed_cost / time,
+        as used in the formula above
+    target_batch_duration_secs: (optional) a target for total time per bundle,
+        in seconds
+    clock: (optional) an alternative to time.time for measuring the cost of
+        downstream operations (mostly for testing)
+  """
+  def __init__(self,
+               min_batch_size=1,
+               max_batch_size=1000,
+               target_batch_overhead=.05,
+               target_batch_duration_secs=1,
+               clock=time.time):
+    self._batch_size_estimator = _BatchSizeEstimator(
+        min_batch_size=min_batch_size,
+        max_batch_size=max_batch_size,
+        target_batch_overhead=target_batch_overhead,
+        target_batch_duration_secs=target_batch_duration_secs,
+        clock=clock)
+
+  def expand(self, pcoll):
+    if getattr(pcoll.pipeline.runner, 'is_streaming', False):
+      raise NotImplementedError("Requires stateful processing (BEAM-2687)")
+    elif pcoll.windowing.is_default():
+      # This is the same logic as _WindowAwareBatchingDoFn, but optimized
+      # for that simpler case.
+      return pcoll | ParDo(_GlobalWindowsBatchingDoFn(
+          self._batch_size_estimator))
+    else:
+      return pcoll | ParDo(_WindowAwareBatchingDoFn(self._batch_size_estimator))


http://git-wip-us.apache.org/repos/asf/beam/blob/d226c767/sdks/python/apache_beam/transforms/util_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py
new file mode 100644
index 0000000..6064e2c
--- /dev/null
+++ b/sdks/python/apache_beam/transforms/util_test.py
@@ -0,0 +1,108 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Unit tests for the transforms.util classes."""
+
+import time
+import unittest
+
+import apache_beam as beam
+from apache_beam.testing.test_pipeline import TestPipeline
+from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import equal_to
+from apache_beam.transforms import util
+from apache_beam.transforms import window
+
+
+class FakeClock(object):
+
+  def __init__(self):
+    self._now = time.time()
+
+  def __call__(self):
+    return self._now
+
+  def sleep(self, duration):
+    self._now += duration
+
+
+class BatchElementsTest(unittest.TestCase):
+
+  def test_constant_batch(self):
+    # Assumes a single bundle...
+    with TestPipeline() as p:
+      res = (
+          p
+          | beam.Create(range(35))
+          | util.BatchElements(min_batch_size=10, max_batch_size=10)
+          | beam.Map(len))
+      assert_that(res, equal_to([10, 10, 10, 5]))
+
+  def test_grows_to_max_batch(self):
+    # Assumes a single bundle...
+    with TestPipeline() as p:
+      res = (
+          p
+          | beam.Create(range(164))
+          | util.BatchElements(
+              min_batch_size=1, max_batch_size=50, clock=FakeClock())
+          | beam.Map(len))
+      assert_that(res, equal_to([1, 1, 2, 4, 8, 16, 32, 50, 50]))
+
+  def test_windowed_batches(self):
+    # Assumes a single bundle, in order...
+    with TestPipeline() as p:
+      res = (
+          p
+          | beam.Create(range(47))
+          | beam.Map(lambda t: window.TimestampedValue(t, t))
+          | beam.WindowInto(window.FixedWindows(30))
+          | util.BatchElements(
+              min_batch_size=5, max_batch_size=10, clock=FakeClock())
+          | beam.Map(len))
+      assert_that(res, equal_to([
+          5, 5, 10, 10,  # elements in [0, 30)
+          10, 7,         # elements in [30, 47)
+          ]))
+
+  def test_target_duration(self):
+    clock = FakeClock()
+    batch_estimator = util._BatchSizeEstimator(
+        target_batch_overhead=None, target_batch_duration_secs=10, clock=clock)
+    batch_duration = lambda batch_size: 1 + .7 * batch_size
+    # 1 + 12 * .7 is as close to the 10s target as we can get.
+    expected_sizes = [1, 2, 4, 8, 12, 12, 12]
+    actual_sizes = []
+    for _ in range(len(expected_sizes)):
+      actual_sizes.append(batch_estimator.next_batch_size())
+      with batch_estimator.record_time(actual_sizes[-1]):
+        clock.sleep(batch_duration(actual_sizes[-1]))
+    self.assertEqual(expected_sizes, actual_sizes)
+
+  def test_target_overhead(self):
+    clock = FakeClock()
+    batch_estimator = util._BatchSizeEstimator(
+        target_batch_overhead=.05, target_batch_duration_secs=None, clock=clock)
+    batch_duration = lambda batch_size: 1 + .7 * batch_size
+    # At 27 items, a batch takes ~20 seconds with ~5% (~1 second) overhead.
+    expected_sizes = [1, 2, 4, 8, 16, 27, 27, 27]
+    actual_sizes = []
+    for _ in range(len(expected_sizes)):
+      actual_sizes.append(batch_estimator.next_batch_size())
+      with batch_estimator.record_time(actual_sizes[-1]):
+        clock.sleep(batch_duration(actual_sizes[-1]))
+    self.assertEqual(expected_sizes, actual_sizes)
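+    # Where 27 comes from: batch_duration(n) = 1 + .7 * n gives model
+    # constants a = 1 (fixed cost) and b = .7 (per-element cost), and
+    # next_batch_size() caps growth at the solution of
+    #   a / (a + b * x) = target_batch_overhead,
+    # i.e. x = (1 / .7) * (1 / .05 - 1) = 27.14..., truncated to 27.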