This is an automated email from the ASF dual-hosted git repository.

tvalentyn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 66a8a4c  Adding performance improvements to ApproximateQuantiles. (#13175)
66a8a4c is described below

commit 66a8a4c5e97f674f594c2194da736ad63a842dfe
Author: Ihor Indyk <ihor.in...@gmail.com>
AuthorDate: Fri Feb 19 01:50:04 2021 -0500

    Adding performance improvements to ApproximateQuantiles. (#13175)
---
 sdks/python/apache_beam/transforms/stats.pxd     |  60 ++
 sdks/python/apache_beam/transforms/stats.py      | 719 ++++++++++++++---------
 sdks/python/apache_beam/transforms/stats_test.py | 114 +++-
 sdks/python/setup.py                             |   1 +
 4 files changed, 610 insertions(+), 284 deletions(-)

diff --git a/sdks/python/apache_beam/transforms/stats.pxd b/sdks/python/apache_beam/transforms/stats.pxd
new file mode 100644
index 0000000..e67c012
--- /dev/null
+++ b/sdks/python/apache_beam/transforms/stats.pxd
@@ -0,0 +1,60 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+cimport cython
+from libc.stdint cimport int64_t
+
+cdef class _QuantileSpec(object):
+  cdef readonly int64_t buffer_size
+  cdef readonly int64_t num_buffers
+  cdef readonly bint weighted
+  cdef readonly key
+  cdef readonly bint reverse
+  cdef readonly weighted_key
+  cdef readonly less_than
+
+cdef class _QuantileBuffer(object):
+  cdef readonly elements
+  cdef readonly weights
+  cdef readonly bint weighted
+  cdef readonly int64_t level
+  cdef readonly min_val
+  cdef readonly max_val
+  cdef readonly _iter
+
+cdef class _QuantileState(object):
+  cdef readonly _QuantileSpec spec
+  cdef public buffers
+  cdef public unbuffered_elements
+  cdef public unbuffered_weights
+  cdef public add_unbuffered
+  cpdef bint is_empty(self)
+  @cython.locals(num_new_buffers=int64_t, idx=int64_t)
+  cpdef _add_unbuffered(self, elements, offset_fn)
+  @cython.locals(num_new_buffers=int64_t, idx=int64_t)
+  cpdef _add_unbuffered_weighted(self, elements, offset_fn)
+  cpdef finalize(self)
+  @cython.locals(min_level=int64_t)
+  cpdef collapse_if_needed(self, offset_fn)
+
+
+@cython.locals(new_level=int64_t, new_weight=double, step=double, offset=double)
+cdef _QuantileBuffer _collapse(buffers, offset_fn, _QuantileSpec spec)
+
+@cython.locals(j=int64_t)
+cdef _interpolate(buffers, int64_t count, double step, double offset,
+                  _QuantileSpec spec)
\ No newline at end of file
diff --git a/sdks/python/apache_beam/transforms/stats.py b/sdks/python/apache_beam/transforms/stats.py
index 9d19bc4..bd6dc72 100644
--- a/sdks/python/apache_beam/transforms/stats.py
+++ b/sdks/python/apache_beam/transforms/stats.py
@@ -36,14 +36,12 @@ import heapq
 import itertools
 import logging
 import math
-import sys
 import typing
 from builtins import round
 from typing import Any
-from typing import Generic
-from typing import Iterable
+from typing import Callable
 from typing import List
-from typing import Sequence
+from typing import Tuple
 
 from apache_beam import coders
 from apache_beam import typehints
@@ -61,30 +59,34 @@ T = typing.TypeVar('T')
 K = typing.TypeVar('K')
 V = typing.TypeVar('V')
 
+try:
+  import mmh3  # pylint: disable=import-error
 
-def _get_default_hash_fn():
-  """Returns either murmurhash or md5 based on installation."""
-  try:
-    import mmh3  # pylint: disable=import-error
+  def _mmh3_hash(value):
+    # mmh3.hash64 returns two 64-bit unsigned integers
+    return mmh3.hash64(value, seed=0, signed=False)[0]
+
+  _default_hash_fn = _mmh3_hash
+  _default_hash_fn_type = 'mmh3'
+except ImportError:
+
+  def _md5_hash(value):
+    # md5 is a 128-bit hash, so we truncate the hexdigest (string of 32
+    # hexadecimal digits) to 16 digits and convert to int to get the 64-bit
+    # integer fingerprint.
+    return int(hashlib.md5(value).hexdigest()[:16], 16)
 
-    def _mmh3_hash(value):
-      # mmh3.hash64 returns two 64-bit unsigned integers
-      return mmh3.hash64(value, seed=0, signed=False)[0]
+  _default_hash_fn = _md5_hash
+  _default_hash_fn_type = 'md5'
 
-    return _mmh3_hash
 
-  except ImportError:
+def _get_default_hash_fn():
+  """Returns either murmurhash or md5 based on installation."""
+  if _default_hash_fn_type == 'md5':
     logging.warning(
         'Couldn\'t find murmurhash. Install mmh3 for a faster implementation of'
         'ApproximateUnique.')
-
-    def _md5_hash(value):
-      # md5 is a 128-bit hash, so we truncate the hexdigest (string of 32
-      # hexadecimal digits) to 16 digits and convert to int to get the 64-bit
-      # integer fingerprint.
-      return int(hashlib.md5(value).hexdigest()[:16], 16)
-
-    return _md5_hash
+  return _default_hash_fn
 
 
 class ApproximateUnique(object):
@@ -297,9 +299,20 @@ class ApproximateQuantiles(object):
       weighted=True
 
     out: [0, 2, 5, 7, 100]
+
+    in: [list(range(10)), ..., list(range(90, 101))], num_quantiles=5,
+      input_batched=True
+
+    out: [0, 25, 50, 75, 100]
+
+    in: [(list(range(10)), [1]*10), (list(range(10)), [0]*10), ...,
+      (list(range(90, 101)), [0]*11)], num_quantiles=5, input_batched=True,
+      weighted=True
+
+    out: [0, 2, 5, 7, 100]
   """
   @staticmethod
-  def _display_data(num_quantiles, key, reverse, weighted):
+  def _display_data(num_quantiles, key, reverse, weighted, input_batched):
     return {
         'num_quantiles': DisplayDataItem(num_quantiles, label='Quantile Count'),
         'key': DisplayDataItem(
@@ -307,7 +320,9 @@ class ApproximateQuantiles(object):
             if hasattr(key, '__name__') else key.__class__.__name__,
             label='Record Comparer Key'),
         'reverse': DisplayDataItem(str(reverse), label='Is Reversed'),
-        'weighted': DisplayDataItem(str(weighted), label='Is Weighted')
+        'weighted': DisplayDataItem(str(weighted), label='Is Weighted'),
+        'input_batched': DisplayDataItem(
+            str(input_batched), label='Is Input Batched'),
     }
 
   @typehints.with_input_types(
@@ -327,12 +342,24 @@ class ApproximateQuantiles(object):
       weighted: (optional) if set to True, the transform returns weighted
         quantiles. The input PCollection is then expected to contain tuples of
         input values with the corresponding weight.
+      input_batched: (optional) if set to True, the transform expects each
+        element of input PCollection to be a batch, which is a list of elements
+        for non-weighted case and a tuple of lists of elements and weights for
+        weighted. Provides a way to accumulate multiple elements at a time more
+        efficiently.
     """
-    def __init__(self, num_quantiles, key=None, reverse=False, weighted=False):
+    def __init__(
+        self,
+        num_quantiles,
+        key=None,
+        reverse=False,
+        weighted=False,
+        input_batched=False):
       self._num_quantiles = num_quantiles
       self._key = key
       self._reverse = reverse
       self._weighted = weighted
+      self._input_batched = input_batched
 
     def expand(self, pcoll):
       return pcoll | CombineGlobally(
@@ -340,14 +367,16 @@ class ApproximateQuantiles(object):
               num_quantiles=self._num_quantiles,
               key=self._key,
               reverse=self._reverse,
-              weighted=self._weighted))
+              weighted=self._weighted,
+              input_batched=self._input_batched))
 
     def display_data(self):
       return ApproximateQuantiles._display_data(
           num_quantiles=self._num_quantiles,
           key=self._key,
           reverse=self._reverse,
-          weighted=self._weighted)
+          weighted=self._weighted,
+          input_batched=self._input_batched)
 
   @typehints.with_input_types(
       typehints.Union[typing.Tuple[K, V],
@@ -368,12 +397,24 @@ class ApproximateQuantiles(object):
       weighted: (optional) if set to True, the transform returns weighted
         quantiles. The input PCollection is then expected to contain tuples of
         input values with the corresponding weight.
+      input_batched: (optional) if set to True, the transform expects each
+        element of input PCollection to be a batch, which is a list of elements
+        for non-weighted case and a tuple of lists of elements and weights for
+        weighted. Provides a way to accumulate multiple elements at a time more
+        efficiently.
     """
-    def __init__(self, num_quantiles, key=None, reverse=False, weighted=False):
+    def __init__(
+        self,
+        num_quantiles,
+        key=None,
+        reverse=False,
+        weighted=False,
+        input_batched=False):
       self._num_quantiles = num_quantiles
       self._key = key
       self._reverse = reverse
       self._weighted = weighted
+      self._input_batched = input_batched
 
     def expand(self, pcoll):
       return pcoll | CombinePerKey(
@@ -381,69 +422,106 @@ class ApproximateQuantiles(object):
               num_quantiles=self._num_quantiles,
               key=self._key,
               reverse=self._reverse,
-              weighted=self._weighted))
+              weighted=self._weighted,
+              input_batched=self._input_batched))
 
     def display_data(self):
       return ApproximateQuantiles._display_data(
           num_quantiles=self._num_quantiles,
           key=self._key,
           reverse=self._reverse,
-          weighted=self._weighted)
+          weighted=self._weighted,
+          input_batched=self._input_batched)
 
 
-class _QuantileBuffer(Generic[T]):
+class _QuantileSpec(object):
+  """Quantiles computation specifications."""
+  def __init__(self, buffer_size, num_buffers, weighted, key, reverse):
+    # type: (int, int, bool, Any, bool) -> None
+    self.buffer_size = buffer_size
+    self.num_buffers = num_buffers
+    self.weighted = weighted
+    self.key = key
+    self.reverse = reverse
+
+    # Used to sort tuples of values and weights.
+    self.weighted_key = None if key is None else (lambda x: key(x[0]))
+
+    # Used to compare values.
+    if reverse and key is None:
+      self.less_than = lambda a, b: a > b
+    elif reverse:
+      self.less_than = lambda a, b: key(a) > key(b)
+    elif key is None:
+      self.less_than = lambda a, b: a < b
+    else:
+      self.less_than = lambda a, b: key(a) < key(b)
+
+  def get_argsort_key(self, elements):
+    # type: (List) -> Callable[[int], Any]
+
+    """Returns a key for sorting indices of elements by element's value."""
+    if self.key is None:
+      return elements.__getitem__
+    else:
+      return lambda idx: self.key(elements[idx])
+
+  def __reduce__(self):
+    return (
+        self.__class__,
+        (
+            self.buffer_size,
+            self.num_buffers,
+            self.weighted,
+            self.key,
+            self.reverse))
+
+
+class _QuantileBuffer(object):
   """A single buffer in the sense of the referenced algorithm.
   (see http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.6.6513&rep=rep1
   &type=pdf and ApproximateQuantilesCombineFn for further information)"""
-  def __init__(self, elements, weighted, level=0, weight=1):
-    # type: (Sequence[T], bool, int, int) -> None
-    # In case of weighted quantiles, elements are tuples of values and weights.
+  def __init__(
+      self, elements, weights, weighted, level=0, min_val=None, max_val=None):
+    # type: (List, List, bool, int, Any, Any) -> None
     self.elements = elements
+    # In non-weighted case weights contains a single element representing weight
+    # of the buffer in the sense of the original algorithm. In weighted case,
+    # it stores weights of individual elements.
+    self.weights = weights
     self.weighted = weighted
     self.level = level
-    self.weight = weight
-
-  def __lt__(self, other):
-    if self.weighted:
-      return [element[0] for element in self.elements
-              ] < [element[0] for element in other.elements]
+    if min_val is None or max_val is None:
+      # Buffer is always initialized with sorted elements.
+      self.min_val = elements[0]
+      self.max_val = elements[-1]
     else:
-      return self.elements < other.elements
-
-  def sized_iterator(self):
-    class QuantileBufferIterator(object):
-      def __init__(self, elem, weighted, weight):
-        self._iter = iter(elem)
-        self.weighted = weighted
-        self.weight = weight
-
-      def __iter__(self):
-        return self
+      # Note that collapsed buffer may not contain min and max in the list of
+      # elements.
+      self.min_val = min_val
+      self.max_val = max_val
 
-      def __next__(self):
-        if self.weighted:
-          return next(self._iter)
-        else:
-          value = next(self._iter)
-          return (value, self.weight)
+  def __iter__(self):
+    return zip(
+        self.elements,
+        self.weights if self.weighted else itertools.repeat(self.weights[0]))
 
-      next = __next__  # For Python 2
-
-    return QuantileBufferIterator(self.elements, self.weighted, self.weight)
+  def __lt__(self, other):
+    return self.level < other.level
 
 
-class _QuantileState(Generic[T]):
+class _QuantileState(object):
   """
   Compact summarization of a collection on which quantiles can be estimated.
   """
-  min_val = None  # type: Any  # Holds smallest item in the list
-  max_val = None  # type: Any  # Holds largest item in the list
-
-  def __init__(self, buffer_size, num_buffers, unbuffered_elements, buffers):
-    # type: (int, int, List[Any], List[_QuantileBuffer[T]]) -> None
-    self.buffer_size = buffer_size
-    self.num_buffers = num_buffers
+  def __init__(self, unbuffered_elements, unbuffered_weights, buffers, spec):
+    # type: (List, List, List[_QuantileBuffer], _QuantileSpec) -> None
     self.buffers = buffers
+    self.spec = spec
+    if spec.weighted:
+      self.add_unbuffered = self._add_unbuffered_weighted
+    else:
+      self.add_unbuffered = self._add_unbuffered
 
     # The algorithm requires that the manipulated buffers always be filled to
     # capacity to perform the collapse operation. This operation can be extended
@@ -452,6 +530,17 @@ class _QuantileState(Generic[T]):
     # into new, full buffers and then take them into account when computing the
     # final output.
     self.unbuffered_elements = unbuffered_elements
+    self.unbuffered_weights = unbuffered_weights
+
+  # This is needed for pickling to work when Cythonization is enabled.
+  def __reduce__(self):
+    return (
+        self.__class__,
+        (
+            self.unbuffered_elements,
+            self.unbuffered_weights,
+            self.buffers,
+            self.spec))
 
   def is_empty(self):
     # type: () -> bool
@@ -459,8 +548,219 @@ class _QuantileState(Generic[T]):
     """Check if the buffered & unbuffered elements are empty or not."""
     return not self.unbuffered_elements and not self.buffers
 
+  def _add_unbuffered(self, elements, offset_fn):
+    # type: (List, Any) -> None
+
+    """
+    Add elements to the unbuffered list, creating new buffers and
+    collapsing if needed.
+    """
+    self.unbuffered_elements.extend(elements)
+    num_new_buffers = len(self.unbuffered_elements) // self.spec.buffer_size
+    for idx in range(num_new_buffers):
+      to_buffer = sorted(
+          self.unbuffered_elements[idx * self.spec.buffer_size:(idx + 1) *
+                                   self.spec.buffer_size],
+          key=self.spec.key,
+          reverse=self.spec.reverse)
+      heapq.heappush(
+          self.buffers,
+          _QuantileBuffer(elements=to_buffer, weights=[1], weighted=False))
+
+    if num_new_buffers > 0:
+      self.unbuffered_elements = self.unbuffered_elements[num_new_buffers *
+                                                          self.spec.
+                                                          buffer_size:]
+
+    self.collapse_if_needed(offset_fn)
+
+  def _add_unbuffered_weighted(self, elements, offset_fn):
+    # type: (List, Any) -> None
+
+    """
+    Add elements with weights to the unbuffered list, creating new buffers and
+    collapsing if needed.
+    """
+    if len(elements) == 1:
+      self.unbuffered_elements.append(elements[0][0])
+      self.unbuffered_weights.append(elements[0][1])
+    else:
+      self.unbuffered_elements.extend(elements[0])
+      self.unbuffered_weights.extend(elements[1])
+    num_new_buffers = len(self.unbuffered_elements) // self.spec.buffer_size
+    argsort_key = self.spec.get_argsort_key(self.unbuffered_elements)
+    for idx in range(num_new_buffers):
+      argsort = sorted(
+          range(idx * self.spec.buffer_size, (idx + 1) * self.spec.buffer_size),
+          key=argsort_key,
+          reverse=self.spec.reverse)
+      elements_to_buffer = [self.unbuffered_elements[idx] for idx in argsort]
+      weights_to_buffer = [self.unbuffered_weights[idx] for idx in argsort]
+      heapq.heappush(
+          self.buffers,
+          _QuantileBuffer(
+              elements=elements_to_buffer,
+              weights=weights_to_buffer,
+              weighted=True))
+
+    if num_new_buffers > 0:
+      self.unbuffered_elements = self.unbuffered_elements[num_new_buffers *
+                                                          self.spec.
+                                                          buffer_size:]
+      self.unbuffered_weights = self.unbuffered_weights[num_new_buffers *
+                                                        self.spec.buffer_size:]
+
+    self.collapse_if_needed(offset_fn)
+
+  def finalize(self):
+    # type: () -> None
+
+    """
+    Creates a new buffer using all unbuffered elements. Called before
+    extracting an output. Note that the buffer doesn't have to be put in a
+    proper position since _collapse is not going to be called after.
+    """
+    if self.unbuffered_elements and self.spec.weighted:
+      argsort_key = self.spec.get_argsort_key(self.unbuffered_elements)
+      argsort = sorted(
+          range(len(self.unbuffered_elements)),
+          key=argsort_key,
+          reverse=self.spec.reverse)
+      self.unbuffered_elements = [
+          self.unbuffered_elements[idx] for idx in argsort
+      ]
+      self.unbuffered_weights = [
+          self.unbuffered_weights[idx] for idx in argsort
+      ]
+      self.buffers.append(
+          _QuantileBuffer(
+              self.unbuffered_elements, self.unbuffered_weights, weighted=True))
+      self.unbuffered_weights = []
+    elif self.unbuffered_elements:
+      self.unbuffered_elements.sort(
+          key=self.spec.key, reverse=self.spec.reverse)
+      self.buffers.append(
+          _QuantileBuffer(
+              self.unbuffered_elements, weights=[1], weighted=False))
+    self.unbuffered_elements = []
+
+  def collapse_if_needed(self, offset_fn):
+    # type: (Any) -> None
+
+    """
+    Checks if summary has too many buffers and collapses some of them until the
+    limit is restored.
+    """
+    while len(self.buffers) > self.spec.num_buffers:
+      to_collapse = [heapq.heappop(self.buffers), heapq.heappop(self.buffers)]
+      min_level = to_collapse[1].level
+
+      while self.buffers and self.buffers[0].level <= min_level:
+        to_collapse.append(heapq.heappop(self.buffers))
+
+      heapq.heappush(self.buffers, _collapse(to_collapse, offset_fn, self.spec))
+
+
+def _collapse(buffers, offset_fn, spec):
+  # type: (List[_QuantileBuffer], Any, _QuantileSpec) -> _QuantileBuffer
+
+  """
+  Approximates elements from multiple buffers and produces a single buffer.
+  """
+  new_level = 0
+  new_weight = 0
+  for buffer in buffers:
+    # As presented in the paper, there should always be at least two
+    # buffers of the same (minimal) level to collapse, but it is possible
+    # to violate this condition when combining buffers from independently
+    # computed shards. If they differ we take the max.
+    new_level = max([new_level, buffer.level + 1])
+    new_weight = new_weight + sum(buffer.weights)
+  if spec.weighted:
+    step = new_weight / (spec.buffer_size - 1)
+    offset = new_weight / (2 * spec.buffer_size)
+  else:
+    step = new_weight
+    offset = offset_fn(new_weight)
+  new_elements, new_weights, min_val, max_val = \
+      _interpolate(buffers, spec.buffer_size, step, offset, spec)
+  if not spec.weighted:
+    new_weights = [new_weight]
+  return _QuantileBuffer(
+      new_elements, new_weights, spec.weighted, new_level, min_val, max_val)
+
+
+def _interpolate(buffers, count, step, offset, spec):
+  # type: (List[_QuantileBuffer], int, float, float, _QuantileSpec) -> Tuple[List, List, Any, Any]
+
+  """
+  Emulates taking the ordered union of all elements in buffers, repeated
+  according to their weight, and picking out the (k * step + offset)-th elements
+  of this list for `0 <= k < count`.
+  """
+  buffer_iterators = []
+  min_val = buffers[0].min_val
+  max_val = buffers[0].max_val
+  for buffer in buffers:
+    # Calculate extreme values for the union of buffers.
+    min_val = buffer.min_val if spec.less_than(
+        buffer.min_val, min_val) else min_val
+    max_val = buffer.max_val if spec.less_than(
+        max_val, buffer.max_val) else max_val
+    buffer_iterators.append(iter(buffer))
+
+  # Note that `heapq.merge` can also be used here since the buffers are sorted.
+  # In practice, however, `sorted` uses natural order in the union and
+  # significantly outperforms `heapq.merge`.
+  sorted_elements = sorted(
+      itertools.chain.from_iterable(buffer_iterators),
+      key=spec.weighted_key,
+      reverse=spec.reverse)
+
+  if not spec.weighted:
+    # If all buffers have the same weight, then quantiles' indices are evenly
+    # distributed over a range [0, len(sorted_elements)].
+    buffers_have_same_weight = True
+    weight = buffers[0].weights[0]
+    for buffer in buffers:
+      if buffer.weights[0] != weight:
+        buffers_have_same_weight = False
+        break
+    if buffers_have_same_weight:
+      offset = offset / weight
+      step = step / weight
+      max_idx = len(sorted_elements) - 1
+      result = [
+          sorted_elements[min(int(j * step + offset), max_idx)][0]
+          for j in range(count)
+      ]
+      return result, [], min_val, max_val
+
+  sorted_elements_iter = iter(sorted_elements)
+  weighted_element = next(sorted_elements_iter)
+  new_elements = []
+  new_weights = []
+  j = 0
+  current_weight = weighted_element[1]
+  previous_weight = 0
+  while j < count:
+    target_weight = j * step + offset
+    j += 1
+    try:
+      while current_weight <= target_weight:
+        weighted_element = next(sorted_elements_iter)
+        current_weight += weighted_element[1]
+    except StopIteration:
+      pass
+    new_elements.append(weighted_element[0])
+    if spec.weighted:
+      new_weights.append(current_weight - previous_weight)
+      previous_weight = current_weight
+
+  return new_elements, new_weights, min_val, max_val
+
 
-class ApproximateQuantilesCombineFn(CombineFn, Generic[T]):
+class ApproximateQuantilesCombineFn(CombineFn):
   """
   This combiner gives an idea of the distribution of a collection of values
   using approximate N-tiles. The output of this combiner is the list of size of
@@ -483,9 +783,12 @@ class ApproximateQuantilesCombineFn(CombineFn, Generic[T]):
   http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.6.6513&rep=rep1
   &type=pdf
 
-  The default error bound is (1 / N) for uniformly distributed data and
-  min(1e-2, 1 / N) for weighted case, though in practice the accuracy tends to
-  be much better.
+  Note that the weighted quantiles are evaluated using a generalized version of
+  the algorithm referenced in the paper.
+
+  The default error bound is (1 / num_quantiles) for uniformly distributed data
+  and min(1e-2, 1 / num_quantiles) for weighted case, though in practice the
+  accuracy tends to be much better.
 
   Args:
     num_quantiles: Number of quantiles to produce. It is the size of the final
@@ -501,6 +804,8 @@ class ApproximateQuantilesCombineFn(CombineFn, Generic[T]):
     weighted: (optional) if set to True, the combiner produces weighted
       quantiles. The input elements are then expected to be tuples of input
       values with the corresponding weight.
+    input_batched: (optional) if set to True, inputs are expected to be batches
+      of elements.
   """
 
   # For alternating between biasing up and down in the above even weight
@@ -514,7 +819,7 @@ class ApproximateQuantilesCombineFn(CombineFn, Generic[T]):
   # non-optimal. The impact is logarithmic with respect to this value, so this
   # default should be fine for most uses.
   _MAX_NUM_ELEMENTS = 1e9
-  _qs = None  # type: _QuantileState[T]
+  _qs = None  # type: _QuantileState
 
   def __init__(
       self,
@@ -523,29 +828,25 @@ class ApproximateQuantilesCombineFn(CombineFn, Generic[T]):
       num_buffers,  # type: int
       key=None,
       reverse=False,
-      weighted=False):
-    def _comparator(a, b):
-      if key:
-        a, b = key(a), key(b)
-
-      retval = int(a > b) - int(a < b)
-
-      if reverse:
-        return -retval
-
-      return retval
-
-    self._comparator = _comparator
-
+      weighted=False,
+      input_batched=False):
     self._num_quantiles = num_quantiles
-    self._buffer_size = buffer_size
-    self._num_buffers = num_buffers
-    if weighted:
-      self._key = (lambda x: x[0]) if key is None else (lambda x: key(x[0]))
-    else:
-      self._key = key
-    self._reverse = reverse
-    self._weighted = weighted
+    self._spec = _QuantileSpec(buffer_size, num_buffers, weighted, key, reverse)
+    self._input_batched = input_batched
+    if self._input_batched:
+      setattr(self, 'add_input', self._add_inputs)
+
+  def __reduce__(self):
+    return (
+        self.__class__,
+        (
+            self._num_quantiles,
+            self._spec.buffer_size,
+            self._spec.num_buffers,
+            self._spec.key,
+            self._spec.reverse,
+            self._spec.weighted,
+            self._input_batched))
 
   @classmethod
   def create(
@@ -555,7 +856,8 @@ class ApproximateQuantilesCombineFn(CombineFn, Generic[T]):
       max_num_elements=None,
       key=None,
       reverse=False,
-      weighted=False):
+      weighted=False,
+      input_batched=False):
     # type: (...) -> ApproximateQuantilesCombineFn
 
     """
@@ -582,11 +884,17 @@ class ApproximateQuantilesCombineFn(CombineFn, Generic[T]):
       weighted: (optional) if set to True, the combiner produces weighted
         quantiles. The input elements are then expected to be tuples of values
         with the corresponding weight.
+      input_batched: (optional) if set to True, inputs are expected to be
+        batches of elements.
     """
     max_num_elements = max_num_elements or cls._MAX_NUM_ELEMENTS
     if not epsilon:
       epsilon = min(1e-2, 1.0 / num_quantiles) \
         if weighted else (1.0 / num_quantiles)
+    # Note that calculation of the buffer size and the number of buffers here
+    # is based on technique used in the Munro-Paterson algorithm. Switching to
+    # the logic used in the "New Algorithm" may result in memory savings since
+    # it results in lower values for b and k in practice.
     b = 2
     while (b - 2) * (1 << (b - 2)) < epsilon * max_num_elements:
       b = b + 1
@@ -598,30 +906,8 @@ class ApproximateQuantilesCombineFn(CombineFn, Generic[T]):
         num_buffers=b,
         key=key,
         reverse=reverse,
-        weighted=weighted)
-
-  def _add_unbuffered(self, qs, elements):
-    # type: (_QuantileState[T], Iterable[T]) -> None
-
-    """
-    Add a new buffer to the unbuffered list, creating a new buffer and
-    collapsing if needed.
-    """
-    qs.unbuffered_elements.extend(elements)
-    if len(qs.unbuffered_elements) >= qs.buffer_size:
-      qs.unbuffered_elements.sort(key=self._key, reverse=self._reverse)
-
-      while len(qs.unbuffered_elements) >= qs.buffer_size:
-        to_buffer = qs.unbuffered_elements[:qs.buffer_size]
-        heapq.heappush(
-            qs.buffers,
-            _QuantileBuffer(
-                elements=to_buffer,
-                weighted=self._weighted,
-                weight=sum([element[1] for element in to_buffer])
-                if self._weighted else 1))
-        qs.unbuffered_elements = qs.unbuffered_elements[qs.buffer_size:]
-        self._collapse_if_needed(qs)
+        weighted=weighted,
+        input_batched=input_batched)
 
   def _offset(self, new_weight):
     # type: (int) -> float
@@ -636,132 +922,33 @@ class ApproximateQuantilesCombineFn(CombineFn, Generic[T]):
       self._offset_jitter = 2 - self._offset_jitter
       return (new_weight + self._offset_jitter) / 2
 
-  def _collapse(self, buffers):
-    # type: (Iterable[_QuantileBuffer[T]]) -> _QuantileBuffer[T]
-    new_level = 0
-    new_weight = 0
-    for buffer_elem in buffers:
-      # As presented in the paper, there should always be at least two
-      # buffers of the same (minimal) level to collapse, but it is possible
-      # to violate this condition when combining buffers from independently
-      # computed shards.  If they differ we take the max.
-      new_level = max([new_level, buffer_elem.level + 1])
-      new_weight = new_weight + buffer_elem.weight
-    if self._weighted:
-      step = new_weight / (self._buffer_size - 1)
-      offset = new_weight / (2 * self._buffer_size)
-    else:
-      step = new_weight
-      offset = self._offset(new_weight)
-    new_elements = self._interpolate(buffers, self._buffer_size, step, offset)
-    return _QuantileBuffer(new_elements, self._weighted, new_level, new_weight)
-
-  def _collapse_if_needed(self, qs):
-    # type: (_QuantileState) -> None
-    while len(qs.buffers) > self._num_buffers:
-      to_collapse = []
-      to_collapse.append(heapq.heappop(qs.buffers))
-      to_collapse.append(heapq.heappop(qs.buffers))
-      min_level = to_collapse[1].level
-
-      while len(qs.buffers) > 0 and qs.buffers[0].level == min_level:
-        to_collapse.append(heapq.heappop(qs.buffers))
-
-      heapq.heappush(qs.buffers, self._collapse(to_collapse))
-
-  def _interpolate(self, i_buffers, count, step, offset):
-    """
-    Emulates taking the ordered union of all elements in buffers, repeated
-    according to their weight, and picking out the (k * step + offset)-th
-    elements of this list for `0 <= k < count`.
-    """
-
-    iterators = []
-    new_elements = []
-    compare_key = self._key
-    if self._key and not self._weighted:
-      compare_key = lambda x: self._key(x[0])
-    for buffer_elem in i_buffers:
-      iterators.append(buffer_elem.sized_iterator())
-
-    # Python 3 `heapq.merge` support key comparison and returns an iterator and
-    # does not pull the data into memory all at once. Python 2 does not
-    # support comparison on its `heapq.merge` api, so we use the itertools
-    # which takes the `key` function for comparison and creates an iterator
-    # from it.
-    if sys.version_info[0] < 3:
-      sorted_elem = iter(
-          sorted(
-              itertools.chain.from_iterable(iterators),
-              key=compare_key,
-              reverse=self._reverse))
-    else:
-      sorted_elem = heapq.merge(
-          *iterators, key=compare_key, reverse=self._reverse)
-
-    weighted_element = next(sorted_elem)
-    current = weighted_element[1]
-    j = 0
-    previous = 0
-    while j < count:
-      target = j * step + offset
-      j = j + 1
-      try:
-        while current <= target:
-          weighted_element = next(sorted_elem)
-          current = current + weighted_element[1]
-      except StopIteration:
-        pass
-      if self._weighted:
-        new_elements.append((weighted_element[0], current - previous))
-        previous = current
-      else:
-        new_elements.append(weighted_element[0])
-    return new_elements
-
   # TODO(BEAM-7746): Signature incompatible with supertype
   def create_accumulator(self):  # type: ignore[override]
-    # type: () -> _QuantileState[T]
+    # type: () -> _QuantileState
     self._qs = _QuantileState(
-        buffer_size=self._buffer_size,
-        num_buffers=self._num_buffers,
         unbuffered_elements=[],
-        buffers=[])
+        unbuffered_weights=[],
+        buffers=[],
+        spec=self._spec)
     return self._qs
 
   def add_input(self, quantile_state, element):
     """
     Add a new element to the collection being summarized by quantile state.
     """
-    value = element[0] if self._weighted else element
-    if quantile_state.is_empty():
-      quantile_state.min_val = quantile_state.max_val = value
-    elif self._comparator(value, quantile_state.min_val) < 0:
-      quantile_state.min_val = value
-    elif self._comparator(value, quantile_state.max_val) > 0:
-      quantile_state.max_val = value
-    self._add_unbuffered(quantile_state, elements=[element])
+    quantile_state.add_unbuffered([element], self._offset)
     return quantile_state
 
-  def add_inputs(self, quantile_state, elements):
-    """Add new elements to the collection being summarized by quantile state.
+  def _add_inputs(self, quantile_state, elements):
+    # type: (_QuantileState, List) -> _QuantileState
+
     """
-    if not elements:
+    Add a batch of elements to the collection being summarized by quantile
+    state.
+    """
+    if len(elements) == 0:
       return quantile_state
-
-    values = [
-        element[0] for element in elements
-    ] if self._weighted else elements
-    min_val = min(values)
-    max_val = max(values)
-    if quantile_state.is_empty():
-      quantile_state.min_val = min_val
-      quantile_state.max_val = max_val
-    elif self._comparator(min_val, quantile_state.min_val) < 0:
-      quantile_state.min_val = min_val
-    elif self._comparator(max_val, quantile_state.max_val) > 0:
-      quantile_state.max_val = max_val
-    self._add_unbuffered(quantile_state, elements=elements)
+    quantile_state.add_unbuffered(elements, self._offset)
     return quantile_state
 
   def merge_accumulators(self, accumulators):
@@ -770,17 +957,16 @@ class ApproximateQuantilesCombineFn(CombineFn, Generic[T]):
     for accumulator in accumulators:
       if accumulator.is_empty():
         continue
-      if not qs.min_val or self._comparator(accumulator.min_val,
-                                            qs.min_val) < 0:
-        qs.min_val = accumulator.min_val
-      if not qs.max_val or self._comparator(accumulator.max_val,
-                                            qs.max_val) > 0:
-        qs.max_val = accumulator.max_val
-
-      self._add_unbuffered(qs, accumulator.unbuffered_elements)
+      if self._spec.weighted:
+        qs.add_unbuffered(
+            [accumulator.unbuffered_elements, accumulator.unbuffered_weights],
+            self._offset)
+      else:
+        qs.add_unbuffered(accumulator.unbuffered_elements, self._offset)
 
       qs.buffers.extend(accumulator.buffers)
-    self._collapse_if_needed(qs)
+    heapq.heapify(qs.buffers)
+    qs.collapse_if_needed(self._offset)
     return qs
 
   def extract_output(self, accumulator):
@@ -791,46 +977,21 @@ class ApproximateQuantilesCombineFn(CombineFn, Generic[T]):
     """
     if accumulator.is_empty():
       return []
-
+    accumulator.finalize()
     all_elems = accumulator.buffers
-    if self._weighted:
-      unbuffered_weight = sum(
-          [element[1] for element in accumulator.unbuffered_elements])
-      total_weight = unbuffered_weight
+    total_weight = 0
+    if self._spec.weighted:
       for buffer_elem in all_elems:
-        total_weight += sum([element[1] for element in buffer_elem.elements])
-      if accumulator.unbuffered_elements:
-        accumulator.unbuffered_elements.sort(
-            key=self._key, reverse=self._reverse)
-        all_elems.append(
-            _QuantileBuffer(
-                accumulator.unbuffered_elements,
-                weighted=True,
-                weight=unbuffered_weight))
-
-      step = 1.0 * total_weight / (self._num_quantiles - 1)
-      offset = (1.0 * total_weight) / (self._num_quantiles - 1)
-      mid_quantiles = [
-          element[0] for element in self._interpolate(
-              all_elems, self._num_quantiles - 2, step, offset)
-      ]
+        total_weight += sum(buffer_elem.weights)
     else:
-      total_weight = len(accumulator.unbuffered_elements)
       for buffer_elem in all_elems:
-        total_weight += accumulator.buffer_size * buffer_elem.weight
-
-      if accumulator.unbuffered_elements:
-        accumulator.unbuffered_elements.sort(
-            key=self._key, reverse=self._reverse)
-        all_elems.append(
-            _QuantileBuffer(accumulator.unbuffered_elements, weighted=False))
-
-      step = 1.0 * total_weight / (self._num_quantiles - 1)
-      offset = (1.0 * total_weight - 1) / (self._num_quantiles - 1)
-      mid_quantiles = self._interpolate(
-          all_elems, self._num_quantiles - 2, step, offset)
-
-    quantiles = [accumulator.min_val]
-    quantiles.extend(mid_quantiles)
-    quantiles.append(accumulator.max_val)
-    return quantiles
+        total_weight += len(buffer_elem.elements) * buffer_elem.weights[0]
+
+    step = total_weight / (self._num_quantiles - 1)
+    offset = (total_weight - 1) / (self._num_quantiles - 1)
+
+    quantiles, _, min_val, max_val = \
+        _interpolate(all_elems, self._num_quantiles - 2, step, offset,
+                     self._spec)
+
+    return [min_val] + quantiles + [max_val]
diff --git a/sdks/python/apache_beam/transforms/stats_test.py b/sdks/python/apache_beam/transforms/stats_test.py
index 860594f..1cd8c8f 100644
--- a/sdks/python/apache_beam/transforms/stats_test.py
+++ b/sdks/python/apache_beam/transforms/stats_test.py
@@ -482,13 +482,116 @@ class ApproximateQuantilesTest(unittest.TestCase):
           equal_to([["ccccc", "aaa", "b"]]),
           label='checkWithKeyAndReversed')
 
+  def test_batched_quantiles(self):
+    with TestPipeline() as p:
+      data = []
+      for i in range(100):
+        data.append([(j / 10, abs(j - 500))
+                     for j in range(i * 10, (i + 1) * 10)])
+      pc = p | Create(data)
+      globally = (
+          pc | 'Globally' >> beam.ApproximateQuantiles.Globally(
+              3, input_batched=True))
+      with_key = (
+          pc | 'Globally with key' >> beam.ApproximateQuantiles.Globally(
+              3, key=sum, input_batched=True))
+      key_with_reversed = (
+          pc | 'Globally with key and reversed' >>
+          beam.ApproximateQuantiles.Globally(
+              3, key=sum, reverse=True, input_batched=True))
+      assert_that(
+          globally,
+          equal_to([[(0.0, 500), (49.9, 1), (99.9, 499)]]),
+          label='checkGlobally')
+      assert_that(
+          with_key,
+          equal_to([[(50.0, 0), (72.5, 225), (99.9, 499)]]),
+          label='checkGloballyWithKey')
+      assert_that(
+          key_with_reversed,
+          equal_to([[(99.9, 499), (72.5, 225), (50.0, 0)]]),
+          label='checkGloballyWithKeyAndReversed')
+
+  def test_batched_weighted_quantiles(self):
+    with TestPipeline() as p:
+      data = []
+      for i in range(100):
+        data.append([[(i / 10, abs(i - 500))
+                      for i in range(i * 10, (i + 1) * 10)], [i] * 10])
+      pc = p | Create(data)
+      globally = (
+          pc | 'Globally' >> beam.ApproximateQuantiles.Globally(
+              3, weighted=True, input_batched=True))
+      with_key = (
+          pc | 'Globally with key' >> beam.ApproximateQuantiles.Globally(
+              3, key=sum, weighted=True, input_batched=True))
+      key_with_reversed = (
+          pc | 'Globally with key and reversed' >>
+          beam.ApproximateQuantiles.Globally(
+              3, key=sum, reverse=True, weighted=True, input_batched=True))
+      assert_that(
+          globally,
+          equal_to([[(0.0, 500), (70.8, 208), (99.9, 499)]]),
+          label='checkGlobally')
+      assert_that(
+          with_key,
+          equal_to([[(50.0, 0), (21.0, 290), (99.9, 499)]]),
+          label='checkGloballyWithKey')
+      assert_that(
+          key_with_reversed,
+          equal_to([[(99.9, 499), (21.0, 290), (50.0, 0)]]),
+          label='checkGloballyWithKeyAndReversed')
+
+  def test_quantiles_merge_accumulators(self):
+    # This test exercises merging multiple buffers and approximation accuracy.
+    # The max_num_elements is set to a small value to trigger buffers collapse
+    # and interpolation. Under the conditions below, buffer_size=125 and
+    # num_buffers=4, so we're only allowed to keep half of the input values.
+    num_accumulators = 100
+    num_quantiles = 5
+    eps = 0.01
+    max_num_elements = 1000
+    combine_fn = ApproximateQuantilesCombineFn.create(
+        num_quantiles, eps, max_num_elements)
+    combine_fn_weighted = ApproximateQuantilesCombineFn.create(
+        num_quantiles, eps, max_num_elements, weighted=True)
+    data = list(range(1000))
+    weights = list(reversed(range(1000)))
+    step = math.ceil(len(data) / num_accumulators)
+    accumulators = []
+    accumulators_weighted = []
+    for i in range(num_accumulators):
+      accumulator = combine_fn.create_accumulator()
+      accumulator_weighted = combine_fn_weighted.create_accumulator()
+      for element, weight in zip(data[i*step:(i+1)*step],
+                                 weights[i*step:(i+1)*step]):
+        accumulator = combine_fn.add_input(accumulator, element)
+        accumulator_weighted = combine_fn_weighted.add_input(
+            accumulator_weighted, (element, weight))
+      accumulators.append(accumulator)
+      accumulators_weighted.append(accumulator_weighted)
+    accumulator = combine_fn.merge_accumulators(accumulators)
+    accumulator_weighted = combine_fn_weighted.merge_accumulators(
+        accumulators_weighted)
+    quantiles = combine_fn.extract_output(accumulator)
+    quantiles_weighted = combine_fn_weighted.extract_output(
+        accumulator_weighted)
+
+    # In fact, the final accuracy is much higher than eps, but we test for a
+    # minimal accuracy here.
+    for q, actual_q in zip(quantiles, [0, 249, 499, 749, 999]):
+      self.assertAlmostEqual(q, actual_q, delta=max_num_elements * eps)
+    for q, actual_q in zip(quantiles_weighted, [0, 133, 292, 499, 999]):
+      self.assertAlmostEqual(q, actual_q, delta=max_num_elements * eps)
+
   @staticmethod
   def _display_data_matcher(instance):
     expected_items = [
         DisplayDataItemMatcher('num_quantiles', instance._num_quantiles),
         DisplayDataItemMatcher('weighted', str(instance._weighted)),
         DisplayDataItemMatcher('key', str(instance._key.__name__)),
-        DisplayDataItemMatcher('reverse', str(instance._reverse))
+        DisplayDataItemMatcher('reverse', str(instance._reverse)),
+        DisplayDataItemMatcher('input_batched', str(instance._input_batched)),
     ]
     return expected_items
 
@@ -551,8 +654,9 @@ class ApproximateQuantilesBufferTest(unittest.TestCase):
     combine_fn = ApproximateQuantilesCombineFn.create(
         num_quantiles=10, max_num_elements=maxInputSize, epsilon=epsilon)
     self.assertEqual(
-        expectedNumBuffers, combine_fn._num_buffers, "Number of buffers")
-    self.assertEqual(expectedBufferSize, combine_fn._buffer_size, "Buffer size")
+        expectedNumBuffers, combine_fn._spec.num_buffers, "Number of buffers")
+    self.assertEqual(
+        expectedBufferSize, combine_fn._spec.buffer_size, "Buffer size")
 
   @parameterized.expand(_build_quantilebuffer_test_data)
   def test_correctness(self, epsilon, maxInputSize, *args):
@@ -561,8 +665,8 @@ class ApproximateQuantilesBufferTest(unittest.TestCase):
     """
     combine_fn = ApproximateQuantilesCombineFn.create(
         num_quantiles=10, max_num_elements=maxInputSize, epsilon=epsilon)
-    b = combine_fn._num_buffers
-    k = combine_fn._buffer_size
+    b = combine_fn._spec.num_buffers
+    k = combine_fn._spec.buffer_size
     n = maxInputSize
     self.assertLessEqual((b - 2) * (1 << (b - 2)) + 0.5, (epsilon * n),
                          '(b-2)2^(b-2) + 1/2 <= eN')
diff --git a/sdks/python/setup.py b/sdks/python/setup.py
index fdc9c5b..624beb3 100644
--- a/sdks/python/setup.py
+++ b/sdks/python/setup.py
@@ -277,6 +277,7 @@ setuptools.setup(
         'apache_beam/runners/worker/opcounters.py',
         'apache_beam/runners/worker/operations.py',
         'apache_beam/transforms/cy_combiners.py',
+        'apache_beam/transforms/stats.py',
         'apache_beam/utils/counters.py',
         'apache_beam/utils/windowed_value.py',
     ]),

Reply via email to